Vectorized welford #2204

Merged · 16 commits into devel · Nov 21, 2022

Conversation

naoyam (Collaborator) commented Nov 19, 2022

Apply loop-invariant code hoisting to serial WelfordOps. For example, when the innermost loop looks like:

for () {
  if (pred) {
    welfordCombine(...);
  }
}

The count input should be invariant when the loop is not a reduction loop; if the predicate is also loop invariant, the loop can be transformed to:

nvfuser_index_t new_count = outN()[0] + 1;
float reciprocal = pred ? 1.f / (float)new_count : 0.f;
for () {
  welfordVectorized(..., new_count, reciprocal);
}
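
For reference, the per-element arithmetic can be sketched as below. This is a minimal sketch with illustrative names and argument order (and an explicit pred argument, as in the inline-predicate variant discussed in the review below); the actual runtime helper may differ:

template <typename T>
__device__ void welfordVectorizedSketch(
    T& avg,
    T& M2,
    nvfuser_index_t& N,
    const T new_value,
    const T reciprocal,              // precomputed 1 / new_count
    const nvfuser_index_t new_count, // precomputed count after this update
    const bool pred) {
  if (pred) {
    T delta = new_value - avg;
    // delta / new_count, with the division hoisted out of the loop and
    // replaced by a multiply
    avg += delta * reciprocal;
    M2 += delta * (new_value - avg);
    N = new_count;
  }
}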

This is meant to optimize outer welford reductions. With this, the generated code for FusionComputeWith6 looks like:

 for(nvfuser_index_t i186 = 0; i186 < (ceilDiv(16, 4)); ++i186) {
      if ((((((((((i186 + nvfuser_zero) * 4) + 3) * (ceilDiv((ceilDiv(((T0.size[0] * T0.size[1]) * T0.size[2]), 16)), 16))) + ((nvfuser_index_t)blockIdx.y)) * 16) + ((nvfuser_index_t)threadIdx.y)) < ((T0.size[0] * T0.size[1]) * T0.size[2])) && (((((((i192 * 2) + ((nvfuser_index_t)blockIdx.x)) * 8) + ((nvfuser_index_t)threadIdx.x)) * 2) + 1) < T0.size[3]))) {
        #pragma unroll
        for(nvfuser_index_t i187 = 0; i187 < 4; ++i187) {
          int i284;
          i284 = (i186 * 4) + i187;
          loadGlobalToLocal<__half, 2, false>(&T1[(i284 * 2)],  &T0[((((((((i186 + nvfuser_zero) * 4) + i187) * (ceilDiv((ceilDiv(((T0.size[0] * T0.size[1]) * T0.size[2]), 16)), 16))) + ((nvfuser_index_t)blockIdx.y)) * 16) + ((nvfuser_index_t)threadIdx.y)) * T0.size[3]) + i271]);
          int i1102;
          i1102 = T12[0] + 1;
          float f1103;
          f1103 = (float)(i1102);
          float f1104;
          f1104 = 1 / f1103;
          #pragma unroll
          for(nvfuser_index_t i189 = 0; i189 < 2; ++i189) {
            float T2[1];
            T2[0]
               = __half2float(T1[(i284 * 2) + i189]);
            welfordVectorized<float>(T13[i189], T11[i189], T12[i189], T2[0], f1104, i1102);
          }
        }
      } else {
        #pragma unroll
        for(nvfuser_index_t i187 = 0; i187 < 4; ++i187) {
          int i915;
          i915 = (((((i186 * 4) + (i187 + nvfuser_zero)) * (ceilDiv((ceilDiv(((T0.size[0] * T0.size[1]) * T0.size[2]), 16)), 16))) + ((nvfuser_index_t)blockIdx.y)) * 16) + ((nvfuser_index_t)threadIdx.y);
          int i391;
          i391 = (i186 * 4) + i187;
          if (((i915 < ((T0.size[0] * T0.size[1]) * T0.size[2])) && (((((((i192 * 2) + ((nvfuser_index_t)blockIdx.x)) * 8) + ((nvfuser_index_t)threadIdx.x)) * 2) + 1) < T0.size[3]))) {
            loadGlobalToLocal<__half, 2, false>(&T1[(i391 * 2)],  &T0[((((((((i186 + nvfuser_zero) * 4) + i187) * (ceilDiv((ceilDiv(((T0.size[0] * T0.size[1]) * T0.size[2]), 16)), 16))) + ((nvfuser_index_t)blockIdx.y)) * 16) + ((nvfuser_index_t)threadIdx.y)) * T0.size[3]) + i271]);
          }
          if (((i915 < ((T0.size[0] * T0.size[1]) * T0.size[2])) && (((((((i192 * 2) + ((nvfuser_index_t)blockIdx.x)) * 8) + ((nvfuser_index_t)threadIdx.x)) * 2) + 1) < T0.size[3]))) {
            int i1105;
            i1105 = T12[0] + 1;
            float f1106;
            f1106 = (float)(i1105);
            float f1107;
            f1107 = 1 / f1106;
            #pragma unroll
            for(nvfuser_index_t i189 = 0; i189 < 2; ++i189) {
              float T2[1];
              T2[0]
                 = __half2float(T1[(i391 * 2) + i189]);
              welfordVectorized<float>(T13[i189], T11[i189], T12[i189], T2[0], f1107, i1105);
            }
          }
        }
      }
    }

Commit messages include:

- Lift the predicated count division outside of the innermost loop if that loop is exactly mapped with vectorized IDs and is not a reduction domain. Targeted to address outer-reduction grid welford tuning.
- Represents the 32-bit floating-point scalar value. Not supported in PyTorch, so it can't be used as an input to fusions.

naoyam marked this pull request as ready for review on Nov 20, 2022.
naoyam changed the title from "[WIP] Vectorized welford" to "Vectorized welford" on Nov 20, 2022.

csarofeen (Owner) left a comment:

Minor questions, but overall looks fine, marking as approved.

case DataType::Float:
  return IrBuilder::create<Double>(DataType::Float);
csarofeen (Owner):

I thought you added float scalar support; why shouldn't there be a float entry here? Is it because this can be user facing?

naoyam (Collaborator, Author):

I did, but then I felt it was sometimes inconvenient to have both Double and Float, so I dropped Float. The dtype field of Double is used to specify its actual type (#2203). The same applies to Int, which can represent int64_t as Int(DataType::Int) and int as Int(DataType::Int32).
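
A hypothetical usage sketch of this pattern (illustrative only, not code from the PR; it assumes DataType::Int32 is the 32-bit integer variant):

auto d   = IrBuilder::create<Double>(DataType::Double); // double scalar
auto f   = IrBuilder::create<Double>(DataType::Float);  // float scalar, same class
auto i64 = IrBuilder::create<Int>(DataType::Int);       // int64_t scalar
auto i32 = IrBuilder::create<Int>(DataType::Int32);     // int scalar, same class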

// welfordVectorized(..., new_count, reciprocal, p);
// }
void vectorizeWithInlinePredicate(WelfordOp* wop) {
  kir::IfThenElse* wop_ite = scope_exprs_.back()->as<kir::IfThenElse>();
csarofeen (Owner):

Should you also double check the ite only has one expression in it?

naoyam (Collaborator, Author):

I added the check in isVectorizableWelford.

if (!(wop_ite->thenBody().size() == 1 && wop_ite->thenBody().at(0) == wop &&
      wop_ite->elseBody().empty())) {
  return false;
}

// Predicated case when the innermost loop has no externally
// visible effect except for the welford outputs
//
// Before:
csarofeen (Owner):

Is the before supposed to be the same in this case as the inline predicate?

naoyam (Collaborator, Author):

Yes, it's checked in isVectorizableWelford:

// The predicate should be either Manual or Inline
if (wop_ite->predicate()->predicate_type() != PredicateType::Manual &&
    wop_ite->predicate()->predicate_type() != PredicateType::Inline) {
  return false;
}

  registerInsertBeforeInnerMostLoop(reciprocal_expr);
} else {
  // Initialize reciprocal as 0
  registerInsertBeforeInnerMostLoop(IrBuilder::create<UnaryOp>(
csarofeen (Owner):

Could this just be where(pred, 1/new_count_float, 0)?

naoyam (Collaborator, Author) commented Nov 21, 2022:

Actually, it doesn't seem to work. where only accepts Val*, so we would need to have 1/new_count_float as a separate expression and use its output as an input to where, which means 1/new_count_float would be executed unconditionally. Since new_count_float can be zero when the predicate is false, this would result in unsafe code. I'm going to leave it as is.
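
To make the concern concrete, a sketch of what the where() formulation would require (illustrative names; one and zero stand for constant Vals):

// The division has to be materialized as its own expression before it can
// feed where(), so it executes even when pred is false and
// new_count_float is zero:
Val* recip = div(one, new_count_float);  // unconditional division
Val* guarded = where(pred, recip, zero); // selection happens too late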

// it is invariant within the loop. The loop index can be replaced
// with whatever value within the loop range since it is
// independent of the loop index.
kir::TensorIndex* hoistCount(kir::TensorIndex* out_N) {
csarofeen (Owner):

Wonder if it would be convenient to add a utility to detect which loops a Val is dependent on so we could assert the invariance this assumes (I know it's really just detected earlier). @zasdfgbnm

naoyam (Collaborator, Author):

Yes, that would be useful. We could probably just do a search for a specific Val in an index math expression.
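
A minimal sketch of such a utility (hypothetical helper, not code from the PR; it assumes DependencyCheck::isDependencyOf and kir::ForLoop::index() can be combined this way):

std::vector<kir::ForLoop*> dependentLoops(
    Val* val,
    const std::vector<kir::ForLoop*>& loops) {
  std::vector<kir::ForLoop*> dependent;
  for (auto loop : loops) {
    // True if the loop index feeds into the index math of val
    if (DependencyCheck::isDependencyOf(loop->index(), val)) {
      dependent.push_back(loop);
    }
  }
  return dependent;
}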

// Hoist the count TensorIndex out of the innermost loop, assuming
// it is invariant within the loop. The loop index can be replaced
// with whatever value within the loop range since it is
// independent of the loop index.
csarofeen (Owner):

Why does the loop index have to be replaced? Is it not already just 0?

naoyam (Collaborator, Author):

No, at this point the loop index is still used to index the count Val, since indexing doesn't know that the value of the loop index doesn't matter.
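
Concretely, this is why the hoisted read in the generated code above appears as T12[0] + 1. The rewrite can be sketched as:

// Inside the loop, indexing produces a read of the count through the loop
// index (e.g. T12[i189]) even though the count is invariant there.
// Hoisting substitutes an arbitrary in-range value, such as 0:
//
//   before (inside the loop):  i1102 = T12[i189] + 1;
//   after  (before the loop):  i1102 = T12[0] + 1;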

//
// The pattern matching here only works with a simple sequence of
// expressions but should be sufficient as it is likely the most
// common situation.
csarofeen (Owner):

Seems like this could be a nice generic pass like (a future) index hoisting. Seems to me it's a cleanup of unrolling loops, i.e. trying to simplify the else of the unrolled ite. Noting as future work.

naoyam (Collaborator, Author):

Agreed. This particular implementation is really simplistic, as I don't expect anything complex to appear in the loop body with a serial WelfordOp; to generalize it, we would need to consider a wider variety of patterns.

// guaranteed that its predicate can be safely replaced with the
// predicate of the WelfordOp. If not, that doesn't necessarily
// mean the expr predicate is different, but likely not
// worthwhile to consider.
csarofeen (Owner):

If two expressions are loop mapped, it doesn't necessarily mean they have the same predicate; it just means their predicates are safe to join. This is because of broadcast resolution: the full predicate is a composite of all the root domains covered by the roots of the expressions. I don't see any specific errors in the code; I'm just noting that this is how we do the unrolled predicate: accumulate the predicates from the various expressions to make sure we cover all the root IterDomains.

naoyam merged commit 2057f37 into devel on Nov 21, 2022.
naoyam (Collaborator, Author) commented Nov 21, 2022:

@zasdfgbnm Merged the PR, but let me know if you have any comments.
