From 3f3ff4e82f2087618877af3f36defffa5ed0834c Mon Sep 17 00:00:00 2001 From: Lin Lan Date: Sat, 21 May 2022 21:04:37 +0800 Subject: [PATCH] Remove redundant code in dynamically_batch --- jraph/_src/utils.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/jraph/_src/utils.py b/jraph/_src/utils.py index ab89bf5..c5059cc 100644 --- a/jraph/_src/utils.py +++ b/jraph/_src/utils.py @@ -987,31 +987,23 @@ def dynamically_batch( raise RuntimeError('Found graph bigger than batch size. Valid Batch ' f'Size: {batch_size}, Graph Size: {graph_size}') - # If this is the first element of the batch, set it and continue. - # Otherwise check if there is space for the graph in the batch: + # Check if there is space for the graph in the batch: # if there is, add it to the batch # if there isn't, return the old batch and start a new batch. - if not accumulated_graphs: + if ((num_accumulated_graphs + element_graphs > n_graph - 1) or + (num_accumulated_nodes + element_nodes > n_node - 1) or + (num_accumulated_edges + element_edges > n_edge)): + batched_graph = batch_np(accumulated_graphs) + yield pad_with_graphs(batched_graph, n_node, n_edge, n_graph) accumulated_graphs = [element] num_accumulated_nodes = element_nodes num_accumulated_edges = element_edges num_accumulated_graphs = element_graphs - continue else: - if ((num_accumulated_graphs + element_graphs > n_graph - 1) or - (num_accumulated_nodes + element_nodes > n_node - 1) or - (num_accumulated_edges + element_edges > n_edge)): - batched_graph = batch_np(accumulated_graphs) - yield pad_with_graphs(batched_graph, n_node, n_edge, n_graph) - accumulated_graphs = [element] - num_accumulated_nodes = element_nodes - num_accumulated_edges = element_edges - num_accumulated_graphs = element_graphs - else: - accumulated_graphs.append(element) - num_accumulated_nodes += element_nodes - num_accumulated_edges += element_edges - num_accumulated_graphs += element_graphs + accumulated_graphs.append(element) + num_accumulated_nodes += element_nodes + num_accumulated_edges += element_edges + num_accumulated_graphs += element_graphs # We may still have data in batched graph. if accumulated_graphs: