csarofeen
Owner

There was a big issue with predicate generation: when unrolling a large set of pointwise ops, we would generate a highly redundant predicate:

  if ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T4.size[0] * T4.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T4.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T0.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T4.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T4.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T8.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T0.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T12.size[0] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T4.size[0] * T4.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T4.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T12.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T8.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T4.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T0.size[0] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T4.size[0] * T4.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T8.size[0] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T0.size[0] * T0.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T12.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T12.size[1] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T12.size[0] * T12.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T0.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T8.size[1] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T4.size[0] * T4.size[1] ) ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T8.size[0] * T8.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T8.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T4.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T0.size[1] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T8.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T12.size[0] ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) % T12.size[1] ) < T12.size[1] ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T12.size[0] * T12.size[1] ) ) ) 
&& ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T0.size[0] * T0.size[1] ) ) ) 
&& ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T0.size[0] ) ) )

instead of generating a simplified but complete predicate:

 if ( ( ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) / T12.size[1] ) < T8.size[0] )
 && ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < ( T12.size[0] * T12.size[1] ) ) ) )
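To make the redundancy concrete, here is a minimal C++ sketch that removes textually identical terms from a conjunction. It is not the change made in this PR: the helper names `dedupTerms` and `joinAnd` are hypothetical, and terms are modeled as plain strings rather than IR nodes. Exact-match deduplication alone would not shrink the long predicate above to two terms, since that also requires knowing which tensor extents are equivalent.

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical helper: keep the first occurrence of each predicate term,
// preserving the original order. Terms are compared as exact strings.
std::vector<std::string> dedupTerms(const std::vector<std::string>& terms) {
  std::set<std::string> seen;
  std::vector<std::string> unique_terms;
  for (const auto& term : terms) {
    // insert() reports whether the term was newly added to the set.
    if (seen.insert(term).second) {
      unique_terms.push_back(term);
    }
  }
  return unique_terms;
}

// Hypothetical helper: join terms into a single "( a ) && ( b )" expression.
std::string joinAnd(const std::vector<std::string>& terms) {
  std::string out;
  for (std::size_t i = 0; i < terms.size(); ++i) {
    if (i > 0) {
      out += " && ";
    }
    out += "( " + terms[i] + " )";
  }
  return out;
}
```

Passing the list of terms through `dedupTerms` before `joinAnd` emits each distinct check once, in the order it first appeared.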

@csarofeen csarofeen requested a review from naoyam August 21, 2020 15:53
Collaborator

@kevinstephano kevinstephano left a comment

This might be an issue for another time, as the data member already exists, but the underscore suffix seems to be used inconsistently for private data members. p2c_root_map does not have the suffix.

@csarofeen csarofeen merged commit 4ab4110 into 20_8_18_devel Aug 21, 2020
Owner Author

Having the member as p2c_root_map_ probably would have made it easier to spot, but I could have still easily made the same mistake with it.
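As an illustration of the naming point, here is a small C++ sketch of one kind of mix-up a trailing underscore can make easier to spot. It is not the mistake fixed in this PR: the classes `WithoutSuffix` and `WithSuffix` and their `build` methods are hypothetical, and only the member name comes from the discussion above.

```cpp
#include <cstddef>
#include <map>

// Hypothetical class: the private member has no underscore suffix.
class WithoutSuffix {
 public:
  void build() {
    std::map<int, int> p2c_root_map;  // local silently shadows the member
    p2c_root_map[0] = 1;              // fills the local; the member stays empty
  }
  std::size_t size() const { return p2c_root_map.size(); }

 private:
  std::map<int, int> p2c_root_map;
};

// Hypothetical class: the private member carries the underscore suffix.
class WithSuffix {
 public:
  void build() {
    std::map<int, int> p2c_root_map;  // scratch local, clearly not the member
    p2c_root_map[0] = 1;
    p2c_root_map_ = p2c_root_map;     // the store to the member is explicit
  }
  std::size_t size() const { return p2c_root_map_.size(); }

 private:
  std::map<int, int> p2c_root_map_;
};
```

The suffix does not prevent the error. It only makes it visible at the use site whether a name refers to the member or to a local.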

@csarofeen csarofeen deleted the predicate_fix branch June 9, 2021 13:46