2 changes: 1 addition & 1 deletion examples/openai/search_graph_openai.py
@@ -28,7 +28,7 @@
 # ************************************************

 search_graph = SearchGraph(
-    prompt="List me the Chioggia typical dishes",
+    prompt="List me Chioggia's famous dishes",
     config=graph_config
 )
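
For context, this is how the surrounding example presumably drives SearchGraph end to end. A minimal sketch, assuming the config shape used by the repo's other OpenAI examples; the api_key placeholder and model name are assumptions, not part of this diff:

from scrapegraphai.graphs import SearchGraph

# assumed config shape; the real example loads the key from the environment
graph_config = {
    "llm": {
        "api_key": "YOUR_OPENAI_API_KEY",  # placeholder
        "model": "gpt-3.5-turbo",          # assumed model name
    },
}

search_graph = SearchGraph(
    prompt="List me Chioggia's famous dishes",
    config=graph_config,
)

result = search_graph.run()
print(result)
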
103 changes: 79 additions & 24 deletions scrapegraphai/nodes/graph_iterator_node.py
@@ -2,12 +2,18 @@
 GraphIterator Module
 """

-from typing import List, Optional
 import asyncio
 import copy
-from tqdm import tqdm
+from typing import List, Optional
+
+from tqdm.asyncio import tqdm

 from .base_node import BaseNode


+_default_batchsize = 16
+

 class GraphIteratorNode(BaseNode):
     """
     A node responsible for instantiating and running multiple graph instances in parallel.
@@ -23,12 +29,20 @@ class GraphIteratorNode(BaseNode):
         node_name (str): The unique identifier name for the node, defaulting to "GraphIterator".
     """

-    def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "GraphIterator"):
+    def __init__(
+        self,
+        input: str,
+        output: List[str],
+        node_config: Optional[dict] = None,
+        node_name: str = "GraphIterator",
+    ):
         super().__init__(node_name, "node", input, output, 2, node_config)

-        self.verbose = False if node_config is None else node_config.get("verbose", False)
+        self.verbose = (
+            False if node_config is None else node_config.get("verbose", False)
+        )

     def execute(self, state: dict) -> dict:
"""
Executes the node's logic to instantiate and run multiple graph instances in parallel.

Expand All @@ -43,37 +57,78 @@ def execute(self, state: dict) -> dict:
KeyError: If the input keys are not found in the state, indicating that the
necessary information for running the graph instances is missing.
"""
+        batchsize = self.node_config.get("batchsize", _default_batchsize)

         if self.verbose:
-            print(f"--- Executing {self.node_name} Node ---")
+            print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")

+        try:
+            eventloop = asyncio.get_event_loop()
+        except RuntimeError:
+            eventloop = None
+
+        if eventloop and eventloop.is_running():
+            state = eventloop.run_until_complete(self._async_execute(state, batchsize))
+        else:
+            state = asyncio.run(self._async_execute(state, batchsize))
+
+        return state

+    async def _async_execute(self, state: dict, batchsize: int) -> dict:
+        """asynchronously executes the node's logic with multiple graph instances
+        running in parallel, using a semaphore of some size for concurrency regulation
+
+        Args:
+            state: The current state of the graph.
+            batchsize: The maximum number of concurrent instances allowed.
+
+        Returns:
+            The updated state with the output key containing the results
+            aggregated out of all parallel graph instances.
+
+        Raises:
+            KeyError: If the input keys are not found in the state.
+        """
+
-        # Interpret input keys based on the provided input expression
+        # interprets input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
+        # fetches data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
         urls = input_data[1]

         graph_instance = self.node_config.get("graph_instance", None)

         if graph_instance is None:
-            raise ValueError("Graph instance is required for graph iteration.")
+            raise ValueError("graph instance is required for concurrent execution")

-        # set the prompt and source for each url
+        # sets the prompt for the graph instance
         graph_instance.prompt = user_prompt

-        graphs_instances = []
+        participants = []
+
+        # semaphore to limit the number of concurrent tasks
+        semaphore = asyncio.Semaphore(batchsize)

+        async def _async_run(graph):
+            async with semaphore:
+                return await asyncio.to_thread(graph.run)
+
+        # creates a copy of the graph instance for each endpoint
         for url in urls:
-            # make a copy of the graph instance
-            copy_graph_instance = copy.copy(graph_instance)
-            copy_graph_instance.source = url
-            graphs_instances.append(copy_graph_instance)
+            instance = copy.copy(graph_instance)
+            instance.source = url
+
+            participants.append(instance)

-        # run the graph for each url and use tqdm for progress bar
-        graphs_answers = []
-        for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose):
-            result = graph.run()
-            graphs_answers.append(result)
-
-        state.update({self.output[0]: graphs_answers})
+        futures = [_async_run(graph) for graph in participants]
+
+        answers = await tqdm.gather(
+            *futures, desc="processing graph instances", disable=not self.verbose
+        )
+
+        state.update({self.output[0]: answers})

         return state
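
Below is a minimal, self-contained sketch of the concurrency pattern this diff introduces, runnable outside the library: a semaphore caps the number of in-flight runs at batchsize, asyncio.to_thread offloads each blocking run() call to a worker thread (Python 3.9+), and tqdm.gather awaits the batch while drawing a progress bar. FakeGraph and its timing are hypothetical stand-ins for real graph instances.

import asyncio
import time

from tqdm.asyncio import tqdm


class FakeGraph:
    """hypothetical stand-in for a ScrapeGraphAI graph instance"""

    def __init__(self, source: str):
        self.source = source

    def run(self) -> str:
        time.sleep(0.5)  # simulate blocking scraping / LLM work
        return f"result for {self.source}"


async def run_all(graphs, batchsize=4):
    # semaphore to limit the number of concurrent tasks
    semaphore = asyncio.Semaphore(batchsize)

    async def _run_one(graph):
        async with semaphore:
            # keeps the event loop responsive while run() blocks in a thread
            return await asyncio.to_thread(graph.run)

    # like asyncio.gather, but with a progress bar
    return await tqdm.gather(*(_run_one(g) for g in graphs), desc="processing")


if __name__ == "__main__":
    graphs = [FakeGraph(f"https://example.com/{i}") for i in range(8)]
    print(asyncio.run(run_all(graphs)))

One caveat on the dispatch in execute(): run_until_complete() raises RuntimeError when called on a loop that is already running, so callers embedded in an active event loop (e.g. notebooks) typically need a workaround such as nest_asyncio.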