Skip to content
This repository was archived by the owner on May 21, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions jraph/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,31 +987,23 @@ def dynamically_batch(
raise RuntimeError('Found graph bigger than batch size. Valid Batch '
f'Size: {batch_size}, Graph Size: {graph_size}')

# If this is the first element of the batch, set it and continue.
# Otherwise check if there is space for the graph in the batch:
# Check if there is space for the graph in the batch:
# if there is, add it to the batch
# if there isn't, return the old batch and start a new batch.
if not accumulated_graphs:
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
(num_accumulated_nodes + element_nodes > n_node - 1) or
(num_accumulated_edges + element_edges > n_edge)):
batched_graph = batch_np(accumulated_graphs)
yield pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
continue
else:
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
(num_accumulated_nodes + element_nodes > n_node - 1) or
(num_accumulated_edges + element_edges > n_edge)):
batched_graph = batch_np(accumulated_graphs)
yield pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
else:
accumulated_graphs.append(element)
num_accumulated_nodes += element_nodes
num_accumulated_edges += element_edges
num_accumulated_graphs += element_graphs
accumulated_graphs.append(element)
num_accumulated_nodes += element_nodes
num_accumulated_edges += element_edges
num_accumulated_graphs += element_graphs

# We may still have data in batched graph.
if accumulated_graphs:
Expand Down