diff --git a/examples/openai/search_graph_openai.py b/examples/openai/search_graph_openai.py
index 9d23ac3b..7f40ebde 100644
--- a/examples/openai/search_graph_openai.py
+++ b/examples/openai/search_graph_openai.py
@@ -28,7 +28,7 @@
 # ************************************************
 
 search_graph = SearchGraph(
-    prompt="List me the Chioggia typical dishes",
+    prompt="List me Chioggia's famous dishes",
     config=graph_config
 )
 
diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py
index 663adc62..8a71319a 100644
--- a/scrapegraphai/nodes/graph_iterator_node.py
+++ b/scrapegraphai/nodes/graph_iterator_node.py
@@ -2,12 +2,19 @@
 GraphIterator Module
 """
 
-from typing import List, Optional
+import asyncio
 import copy
-from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional
+
+from tqdm.asyncio import tqdm
+
 from .base_node import BaseNode
 
 
+_default_batchsize = 16
+
+
 class GraphIteratorNode(BaseNode):
     """
     A node responsible for instantiating and running multiple graph instances in parallel.
@@ -23,12 +30,20 @@ class GraphIteratorNode(BaseNode):
         node_name (str): The unique identifier name for the node, defaulting to "Parse".
     """
 
-    def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "GraphIterator"):
+    def __init__(
+        self,
+        input: str,
+        output: List[str],
+        node_config: Optional[dict] = None,
+        node_name: str = "GraphIterator",
+    ):
         super().__init__(node_name, "node", input, output, 2, node_config)
 
-        self.verbose = False if node_config is None else node_config.get("verbose", False)
+        self.verbose = (
+            False if node_config is None else node_config.get("verbose", False)
+        )
 
-    def execute(self,  state: dict) -> dict:
+    def execute(self, state: dict) -> dict:
         """
         Executes the node's logic to instantiate and run multiple graph instances in parallel.
 
@@ -43,37 +58,82 @@ def execute(self, state: dict) -> dict:
             KeyError: If the input keys are not found in the state, indicating that the
                         necessary information for running the graph instances is missing.
         """
+        batchsize = (self.node_config or {}).get("batchsize", _default_batchsize)
 
         if self.verbose:
-            print(f"--- Executing {self.node_name} Node ---")
+            print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")
+
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            # no loop is running in this thread, so asyncio.run can own one
+            return asyncio.run(self._async_execute(state, batchsize))
+
+        # a loop is already running in this thread (e.g. inside a notebook);
+        # run_until_complete() and asyncio.run() both raise RuntimeError there,
+        # so drive the coroutine on a fresh loop in a worker thread instead
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            future = executor.submit(
+                asyncio.run, self._async_execute(state, batchsize)
+            )
+            return future.result()
+
+    async def _async_execute(self, state: dict, batchsize: int) -> dict:
+        """asynchronously executes the node's logic with multiple graph instances
+        running in parallel, using a semaphore of size batchsize to cap concurrency
+
+        Args:
+            state: The current state of the graph.
+            batchsize: The maximum number of concurrent instances allowed.
+
+        Returns:
+            The updated state with the output key containing the results
+            aggregated out of all parallel graph instances.
 
-        # Interpret input keys based on the provided input expression
+        Raises:
+            KeyError: If the input keys are not found in the state.
+            ValueError: If no graph instance is configured for the node.
+        """
+
+        # interprets input keys based on the provided input expression
         input_keys = self.get_input_keys(state)
 
-        # Fetching data from the state based on the input keys
+        # fetches data from the state based on the input keys
         input_data = [state[key] for key in input_keys]
 
         user_prompt = input_data[0]
         urls = input_data[1]
 
-        graph_instance = self.node_config.get("graph_instance", None)
+        graph_instance = (self.node_config or {}).get("graph_instance", None)
+
         if graph_instance is None:
-            raise ValueError("Graph instance is required for graph iteration.")
-        
-        # set the prompt and source for each url
+            raise ValueError("graph instance is required for concurrent execution")
+
+        # sets the prompt for the graph instance
         graph_instance.prompt = user_prompt
-        graphs_instances = []
+
+        participants = []
+
+        # semaphore to limit the number of concurrent tasks
+        semaphore = asyncio.Semaphore(batchsize)
+
+        async def _async_run(graph):
+            async with semaphore:
+                return await asyncio.to_thread(graph.run)
+
+        # creates a shallow copy of the graph instance for each endpoint
         for url in urls:
-            # make a copy of the graph instance
-            copy_graph_instance = copy.copy(graph_instance)
-            copy_graph_instance.source = url
-            graphs_instances.append(copy_graph_instance)
-        
-        # run the graph for each url and use tqdm for progress bar
-        graphs_answers = []
-        for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose):
-            result = graph.run()
-            graphs_answers.append(result)
-
-        state.update({self.output[0]: graphs_answers})
+            instance = copy.copy(graph_instance)
+            instance.source = url
+
+            participants.append(instance)
+
+        futures = [_async_run(graph) for graph in participants]
+
+        answers = await tqdm.gather(
+            *futures, desc="processing graph instances", disable=not self.verbose
+        )
+
+        state.update({self.output[0]: answers})
+
         return state