Conversation

@shmsong shmsong commented May 9, 2022

This PR is a quick patch for redundant predicate sync insertion.

A sync is needed for a redundant parallel type unless every use chain of the redundantly written value in smem/gmem arrives at consumers that are themselves redundant writes of the same parallel type.

This PR patches the insertion so that all redundant writes are sync'ed, avoiding race conditions that could happen on devel TOT.

Detection of the cases where a sync is not needed for redundant types will come in a follow-up.
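
As a rough illustration of that rule (a hand-written CUDA sketch, not fuser-generated output; the kernel and buffer names are made up for this example), a smem write that is redundant in TIDy, i.e. performed only by the threadIdx.y == 0 slice of the block, needs a block sync before any consumer that is not redundant in the same parallel type:

__global__ void redundant_write_sketch(const float* in, float* out) {
  // Assumes a (32, N) thread block, as in the generated kernels below.
  __shared__ float smem[32];
  // Redundant write: only the threadIdx.y == 0 slice produces the value.
  if (threadIdx.y == 0) {
    smem[threadIdx.x] = in[threadIdx.x];
  }
  // Sync required: the consumer below reads across all threadIdx.y, so threads
  // outside the writing slice would otherwise race with the write above.
  __syncthreads();
  out[threadIdx.y * 32 + threadIdx.x] = smem[threadIdx.x];
}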

@shmsong shmsong changed the base branch from master to devel May 9, 2022 21:01
@shmsong shmsong mentioned this pull request May 9, 2022
@shmsong shmsong changed the title from "WIP: Patch sync insertion for redundant predicated writes" to "Patch sync insertion for redundant predicated writes" May 9, 2022
@shmsong shmsong requested review from naoyam and csarofeen May 9, 2022 22:11
launch_params_.gdimy() * launch_params_.gdimz(),
"Wanted to launch a cooperative kernel, however the number of blocks is greater than ",
"what can be resident on the GPU at once. Need: ",
launch_params_.gdimx() * launch_params_.gdimy() * launch_params_.gdimz(),
Author (shmsong) commented on the diff lines above:
Unrelated formatting.

Comment on lines +22715 to +22716
tv0->computeAt(tv3, 0);
tv1->computeAt(tv3, 0);
Collaborator:
Are these meant to do something?

Author (shmsong):
Not really. Just making sure all the CA parameters have a value. I vaguely remember we didn't have a default behavior without any CA setting, but that was a long while ago.


naoyam commented May 9, 2022

What I mentioned in the MMA PR was that, when we have a chain of redundant exprs, I was wondering whether each one would be synchronized. I added a variation of the test to see what happens, and here's the generated code:

__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 2> T1, Tensor<float, 2> T3) {
  alignas(16) extern __shared__ char array[];
  unsigned offset = 0;
  offset = alignBufferSize(offset, 16);
  float* T4 = reinterpret_cast<float*>(array + offset);
  offset += (32 * sizeof(float));
  // Alias Allocation - shared
  auto& T2 = T4;
  if ((((nvfuser_index_t)threadIdx.y) == 0)) {
    T4[((nvfuser_index_t)threadIdx.x)]
       = T0[(((nvfuser_index_t)threadIdx.x) * T0.stride[0])];
  }
  __barrier_sync(0);
  if ((((nvfuser_index_t)threadIdx.y) == 0)) {
    T2[((nvfuser_index_t)threadIdx.x)]
       = T4[((nvfuser_index_t)threadIdx.x)];
  }
  __barrier_sync(0);
  T3[(((nvfuser_index_t)threadIdx.y) * 32) + ((nvfuser_index_t)threadIdx.x)]
    = T2[((nvfuser_index_t)threadIdx.x)]
    + T1[(((nvfuser_index_t)threadIdx.y) * T1.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T1.stride[1])];
}

The point I was making is that the first sync is redundant.

I guess this is not common, so I think this is fine for now, but I wanted to clarify my concern for future optimization.


naoyam commented May 9, 2022

Please remove or merge the added test as you'd like. I just wanted to demonstrate the case.

@naoyam naoyam left a comment

Thanks for the fix.


shmsong commented May 9, 2022

> What I mentioned in the MMA PR was that, when we have a chain of redundant exprs, I was wondering whether each one would be synchronized. [...] The point I was making is that the first sync is redundant. I guess this is not common, so I think this is fine for now, but I wanted to clarify my concern for future optimization.

Yes. I was planning on handling this in a follow-up, i.e. the case where a redundant write has a use chain containing other redundant writes. So T0 -> T2 is a redundant chain and T2 -> T3 isn't.

These vertical redundant chains seem not too bad to handle. I'd need to think a bit more about whether there are pathological horizontal patterns; in the worst case we probably end up with sub-optimal code unless we reconsider expr ordering.
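
For reference, here is a hand-written sketch of the vertical-chain case (not fuser output; names are illustrative only). When the only consumer of a redundant write is another write guarded by the same parallel-type predicate, the intermediate sync adds nothing, because each reading thread reads exactly the element it wrote; the sync is still needed before the first consumer that is not redundant in that type:

__global__ void redundant_chain_sketch(const float* in, float* out) {
  __shared__ float buf_a[32];
  __shared__ float buf_b[32];
  if (threadIdx.y == 0) {
    // First redundant write (T0 -> T4 in the generated code above).
    buf_a[threadIdx.x] = in[threadIdx.x];
  }
  // A sync here would be redundant: the only reader of buf_a is the write below,
  // guarded by the same TIDy predicate, and each thread reads its own element.
  if (threadIdx.y == 0) {
    // Second redundant write (T4 -> T2 above).
    buf_b[threadIdx.x] = buf_a[threadIdx.x];
  }
  // This sync is needed: the consumer below is not redundant in TIDy.
  __syncthreads();
  // Non-redundant consumer (T2 -> T3 above), read by every threadIdx.y.
  out[threadIdx.y * 32 + threadIdx.x] = buf_b[threadIdx.x];
}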

@shmsong shmsong force-pushed the patch_sync_insertion branch from 38b27fa to 9637c58 on May 9, 2022 23:58

shmsong commented May 9, 2022

> Please remove or merge the added test as you'd like. I just wanted to demonstrate the case.

@naoyam Thanks for the repro. The new test case was moved to #1687 and made into a failing case for the redundant sync insertion.

@shmsong shmsong merged commit f9132b7 into devel May 10, 2022
@shmsong shmsong deleted the patch_sync_insertion branch May 10, 2022 17:28