Prevent JIT from overspecializing to every single size configuration #10844
Conversation
%6 : Float(4!, 4) = aten::expand(%2, %3, %4)
%7 : Float(4, 4) = prim::FusionGroup_0[device=0](%6, %0, %5)
return (%7);
graph(%0 : Float(*, *)
// - Associativity: A simple visual proof is that you can expand 3 tensors
// at the same time by stacking their sizes (with alignment to the right),
// just as you'd do in the case of 2 tensors, but with an intermediate
// (the algorithm ends up being pretty much the same).
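The right-alignment argument can be made concrete with a small sketch. The `broadcast2` helper below is hypothetical (written only to illustrate the comment, not taken from the PR); it follows the usual NumPy/PyTorch broadcasting rules:

```python
from itertools import zip_longest

def broadcast2(s1, s2):
    """Hypothetical helper: the broadcast B(s1, s2) of two shapes,
    computed by stacking the sizes with alignment to the right."""
    out = []
    for a, b in zip_longest(reversed(s1), reversed(s2), fillvalue=1):
        if a != 1 and b != 1 and a != b:
            raise ValueError("shapes %r and %r do not broadcast" % (s1, s2))
        out.append(max(a, b))
    return tuple(reversed(out))

# Expanding 3 tensors at once gives the same result in either grouping:
assert broadcast2(broadcast2((2, 1, 4), (3, 1)), (5, 1, 1, 1)) == \
       broadcast2((2, 1, 4), broadcast2((3, 1), (5, 1, 1, 1)))
```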
torch/csrc/jit/fusion_compiler.cpp
// Proof. A simple exercise for the reader :)
//
// Theorem. If all (pre-concat-)outputs have equal shapes, then we can push the expands to
// (pre-chunk-)inputs, and have all intermediates of the same shape
// Lemma 4. Expands can be collapsed, i.e. E(E(x, s1), s2) = E(x, B(s1, s2)).
// Proof. A simple exercise for the reader :)
//
// Theorem. If all (pre-concat-)outputs have equal shapes, then we can push the expands to
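Lemma 4 is easy to check numerically. A sketch using NumPy's equivalents (`broadcast_to` playing the role of E, `broadcast_shapes` of B; `broadcast_shapes` requires NumPy ≥ 1.20):

```python
import numpy as np

x = np.arange(4).reshape(4, 1)                         # shape (4, 1)
s1, s2 = (4, 5), (3, 4, 5)

lhs = np.broadcast_to(np.broadcast_to(x, s1), s2)      # E(E(x, s1), s2)
rhs = np.broadcast_to(x, np.broadcast_shapes(s1, s2))  # E(x, B(s1, s2))
assert lhs.shape == rhs.shape == (3, 4, 5)
assert (lhs == rhs).all()
```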
torch/csrc/jit/type.h
@@ -183,22 +248,22 @@ struct TORCH_API TensorType : public Type {
  }
  static TypePtr fromNumberType(TypePtr typ);

  static CompleteTensorTypePtr sliceSubtypes(const CompleteTensorTypePtr& type) {
@@ -63,39 +63,30 @@ IValue representativeValue(Value* v) {

void PropagateShapeOnBlock(Block * block, bool insert_expands=true);

// for each node in the schema with type Tensor, extract the TensorType
// for each node in the schema with type Tensor, extract the CompleteTensorType
// returns at::nullopt if any Tensor in the schema does not have a known shape
// ignores non-tensors in the list of inputs
@@ -117,57 +116,123 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
// This node represents a single Tensor value with a specific size
@@ -49,6 +55,82 @@ std::vector<bool> TensorDesc::findContiguous(
  return cont;
}

// Descriptor for chunk-ing an input tensor into subtensors
// OR concat-ing an output tensor from subtensors
struct PartitionDesc {
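The two roles of this descriptor can be sketched in Python (the dataclass and all names below are a hypothetical mirror of the C++ struct, not the actual implementation):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class PartitionDesc:
    """Hypothetical mirror of the C++ struct: nSubtensors == 1 means
    the input/output is not chunked/concatenated at all."""
    nSubtensors: int = 1
    dim: int = 0

desc = PartitionDesc(nSubtensors=2, dim=1)
x = np.arange(12.0).reshape(3, 4)
# chunk an input tensor into subtensors...
subtensors = np.split(x, desc.nSubtensors, axis=desc.dim)
# ...or concat an output tensor from subtensors:
y = np.concatenate([t * 2 for t in subtensors], axis=desc.dim)
assert (y == x * 2).all()
```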
@@ -28,7 +36,7 @@ struct TensorDesc {
  : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
TensorDesc(const at::Tensor& t)
  : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
TensorDesc(TensorTypePtr type)
TensorDesc(CompleteTensorTypePtr type)
torch/csrc/jit/fusion_compiler.h
// an output is actually a concatenation of
// many subtensors that the fusion group produces
std::vector<PartitionDesc> concat_desc;
struct FusedKernelCache {
auto uses = input->uses();
if (uses.size() == 1) {
  Node *user = uses[0].user;
  if (user->kind() == prim::FusedChunk) {
chunk_desc.emplace_back();
flat_inputs.emplace_back(p, agraph.input_desc[input_index++]);
outputs.clear();
outputs.reserve(outputDescriptors().size());
for(auto & od : outputDescriptors()) {
  outputs.push_back(torch::getType(backend(),od.scalar_type).tensor());
  outputs.push_back(ref_type.toScalarType(od.scalar_type).tensor());
torch/csrc/jit/fusion_compiler.cpp
  InterpreterState(fallback_code).runOneStage(stack);
}

void FusedKernelCache::expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size) {
torch/csrc/jit/fusion_compiler.cpp
}

at::optional<std::vector<int64_t>> FusedKernelCache::getMapSize(at::TensorList args, at::IntList arg_subset) {
  // NB: we leave this uninitialized, because an empty size is trivially
if (chunk_desc.nSubtensors == 1) {
  try {
    map_size = at::infer_size(map_size, arg.sizes());
  } catch (std::exception& e) {
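The shape of this run-time check can be sketched in Python. The function name and its return convention are hypothetical analogues of `FusedKernelCache::getMapSize`, with NumPy's `broadcast_shapes` standing in for `at::infer_size`:

```python
import numpy as np

def get_map_size(arg_shapes):
    """Hypothetical analogue of getMapSize: broadcast the argument
    shapes together, starting from the empty size (which is trivially
    broadcastable); None signals 'fall back to the interpreter'."""
    map_size = ()
    try:
        for shape in arg_shapes:
            map_size = np.broadcast_shapes(map_size, shape)
    except ValueError:          # shapes don't broadcast together
        return None
    return map_size

assert get_map_size([(4, 1), (3,)]) == (4, 3)
assert get_map_size([(2, 3), (4,)]) is None
```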
torch/csrc/jit/fusion_compiler.cpp
int64_t num_chunks = chunk_desc.nSubtensors;
int64_t dim = chunk_desc.dim;
if (dim < 0) {
  dim += arg.dim();
if (!arg.sizes().equals(map_size)) {
  arg = arg.expand(map_size);
}
map_size.at(pdesc.dim) /= pdesc.nSubtensors;
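The interplay of the expand and the division can be illustrated by running the step in reverse. All names below are hypothetical: an input that feeds a chunk is expanded to the map size with the chunk dimension scaled back up by `nSubtensors`, then split into equal subtensors:

```python
import numpy as np

n_subtensors, dim = 2, 0
map_size = [2, 3]                  # per-subtensor size seen by the kernel
full_size = list(map_size)
full_size[dim] *= n_subtensors     # undo `map_size.at(dim) /= nSubtensors`
arg = np.broadcast_to(np.ones((1, 3)), full_size)   # the expand step
chunks = np.split(arg, n_subtensors, axis=dim)      # the chunk step
assert all(c.shape == tuple(map_size) for c in chunks)
```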
return std::all_of(tensors.begin(), tensors.end(), [&expected](Value *v) {
  auto actual = v->type()->cast<TensorType>();
  return actual && actual->sizes() == expected->sizes();
});
@@ -898,7 +805,7 @@ struct GraphFuser {
Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind());
chunked_op->copyAttributes(*producer_for_chunk_node);
// Invariant: mappable operators always produce contiguous output
bool PropagateCompleteShapeOnNode(
    Node * node, bool insert_expands, std::vector<CompleteTensorTypePtr> types);

void PropagateCatShape(Node * cat_node) {
Nice work.
We could track contiguity in
Without symbolic sizes, we can't really track stride contracts in a useful way, but ops definitely have contiguity contracts which users know about, because contiguity tells you if you can view() a tensor, e.g. But you're right, let's add it if/when a pass actually desperately wants to know about contiguity.
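The point about contiguity contracts can be illustrated with NumPy, used here only as a stand-in for the torch behaviour (torch's `view()` fails outright on non-contiguous data, whereas NumPy's `reshape` silently copies):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
t = x.T                                  # same storage, non-contiguous strides
assert x.flags['C_CONTIGUOUS'] and not t.flags['C_CONTIGUOUS']

# A zero-copy reshape is only possible on contiguous data; the silent
# copy NumPy makes here is exactly the data movement view() refuses to do:
v = t.reshape(-1)
assert not np.shares_memory(t, v)
```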
Force-pushed from 9b85a71 to f5efae2.
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This looks good -- I have a bunch of individual but minor comments and questions.
Higher level notes for building on top of this:
- This puts the FusionCompiler at the breaking point of complexity. Further functionality added there is going to require refactoring some of the Chunk and Concat logic into a separate phase, given how many times we need to check nSubtensors and do other things to derive correct sizes. The core of the fusion is simple, but this chunk/concat stuff is getting spread all over the place.
- We may notice regressions from overhead in launching fused kernels. For reference, we know the time it takes to do chunk in the interpreter adds significant overhead, and that seems to be on the same order of magnitude as the extra checking added here. We will need to monitor this and optimize if necessary.
@@ -52,27 +52,21 @@ bool isDifferentiable(Node * n) {
    return true;

  if (n->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return static_cast<bool>(n->input(1)->type()->cast<TensorType>());
    return static_cast<bool>(n->input(1)->type()->cast<CompleteTensorType>());
torch/csrc/jit/export.cpp
} else if (kind == TypeKind::TensorType) {
  type_proto->set_denotation("TensorType");
  TensorTypePtr node_type = type->cast<TensorType>();
} else if (kind == TypeKind::CompleteTensorType) {
@@ -225,6 +225,9 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
} else if (kind == "TensorType") {
  // TODO: Don't use DynamicType here
  return DynamicType::get();
} else if (kind == "CompleteTensorType") {
torch/csrc/jit/interpreter.cpp
@@ -719,8 +719,8 @@ struct InterpreterStateImpl {
  current_pc = pc;
  current_stage++;
}
const TensorType & tensorTypeForInput(size_t i) const {
  return *function->preprocess.stage_input_types.at(current_stage).at(i)->expect<TensorType>();
const CompleteTensorType & tensorTypeForInput(size_t i) const {
@@ -150,8 +150,8 @@ void BatchMMBlock(Block* block) {
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : block->nodes()) {
  if (node->kind() == aten::mm &&
      node->input(0)->type()->cast<TensorType>() &&
      node->input(1)->type()->cast<TensorType>()) {
      node->input(0)->type()->cast<CompleteTensorType>() &&
torch/csrc/jit/fusion_compiler.cpp
for(auto & i : inputs) {
  agraph.input_desc.emplace_back(i);
agraph.input_desc = spec.descs();
at::optional<at::ScalarType> scalar_type;
  throw std::runtime_error("cannot compile a CUDA fusion group, CUDA is not enabled.");
#endif
} else {
  JIT_ASSERT(compiler.canCompileOnCPU());
@@ -252,6 +301,14 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) {
  }
  return;
}
// NB: We assume that all shapes are known within fused kernels
  }
}

if (canPropagateShapeByRunningIt(node))
Most of the changes needed to be applied to the fuser, which heavily relied on this info. It now includes some extra run-time shape checks to determine if it can use its fused kernels, or if it should fall back to regular execution instead.
Force-pushed from 38605bf to acc2435.
The last commit forces slicing on types if the dynamic cast succeeds, but the kind doesn't match exactly. Normally we could simply use the copy constructor in this place, but because we're incorrectly comparing only addresses of types in many cases (because we assume they're used as singletons), the slicing needs some extra care. We really should either stop using shared pointers to types in most places (that would also come with a benefit of not having to incref/decref just to check the type), or have a subclass of shared pointer that uses equality on held elements in
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
* upstream/master: (89 commits)
  - move HeatmapMaxKeypointOp unittest to oss
  - fix xfails involving literals (pytorch#10905)
  - Bag of Distributions doc fixes (pytorch#10894)
  - Remove FIXME_zerol() from test_jit.py (pytorch#10900)
  - Increase BC for PackedSequence ctor (pytorch#9864)
  - Remove ability of Scalars to hold Tensors.
  - Begin a bestiary of MSVC/NVCC bugs. (pytorch#10883)
  - Prevent JIT from overspecializing to every single size configuration (pytorch#10844)
  - Handling failing test on ROCm.
  - Update mobile predictor caller's interface
  - Cache isContiguous and numel
  - Create class constant for string literal 'blob_names'
  - Conv BN fusion for 3D conv (pytorch#10239)
  - Stop using symbolic override for tracing RNNs (pytorch#10638)
  - Add registry to pybind_state (pytorch#10759)
  - Remove the nanopb submodule
  - Create at::linear (pytorch#10799)
  - Refactor THCNumerics and add common math functions for at::Half (pytorch#10301)
  - Remove Tensor constructor of Scalar. (pytorch#10852)
  - Revert D9492561: [pytorch][PR] Moving the operator argument to the front for kernelPointwiseApply.
  - ...
Pull Request resolved: pytorch#10844
Differential Revision: D9498705
Pulled By: apaszke
fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61
Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details.

Summary of changes:

- Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The argument behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make the transition easier, `complete_type->cast<TensorType>()` works, and makes our passes work with both kinds of specialization if they don't need the extra detail.
- Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches arguments only at the level of the new `TensorType`.
- Shape analysis can process graphs with both `CompleteTensorType` and `TensorType`.
- The fuser was a part that heavily relied on full shape information being available. Now, we simply try to fuse the largest possible graphs, and have to do run-time checks to make sure they match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implemented using an optimized method exploiting algebraic properties of shapes with broadcasting, and the relations of broadcasting with pointwise ops. A full written proof of correctness of the shape checking algorithm is included in a comment in `graph_fuser.cpp`.

@zdevito @ezyang @mruberry @ngimel @csarofeen
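The check-then-fall-back dispatch described above can be sketched as follows. All names here are hypothetical (this is not the PR's C++ implementation); NumPy's `broadcast_shapes` stands in for the cheap run-time shape check:

```python
import numpy as np

def run_with_fallback(args, fused_kernel, fallback):
    """Hypothetical sketch: validate that the argument shapes broadcast
    together before trusting the fused kernel; on failure, fall back to
    regular interpretation."""
    try:
        np.broadcast_shapes(*(a.shape for a in args))   # cheap run-time check
    except ValueError:
        return fallback(*args)                          # regular interpretation
    return fused_kernel(*args)

add = lambda a, b: a + b    # stand-in for both the fused and the fallback path
x, y = np.ones((4, 1)), np.ones((3,))
out = run_with_fallback([x, y], add, add)
assert out.shape == (4, 3)
```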