diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 9df5548cd5d93..92ce053ad5c82 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -11,6 +11,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/Support/ADTExtras.h" namespace llvm { class BitVector; @@ -274,20 +275,14 @@ class RankedTensorType::Builder { /// Erase a dim from shape @pos. Builder &dropDim(unsigned pos) { assert(pos < shape.size() && "overflow"); - if (storage.empty()) - storage.append(shape.begin(), shape.end()); - storage.erase(storage.begin() + pos); - shape = {storage.data(), storage.size()}; + shape.erase(pos); return *this; } /// Insert a val into shape @pos. Builder &insertDim(int64_t val, unsigned pos) { assert(pos <= shape.size() && "overflow"); - if (storage.empty()) - storage.append(shape.begin(), shape.end()); - storage.insert(storage.begin() + pos, val); - shape = {storage.data(), storage.size()}; + shape.insert(pos, val); return *this; } @@ -296,9 +291,7 @@ class RankedTensorType::Builder { } private: - ArrayRef shape; - // Owning shape data for copy-on-write operations. - SmallVector storage; + CopyOnWriteArrayRef shape; Type elementType; Attribute encoding; }; @@ -313,27 +306,18 @@ class VectorType::Builder { public: /// Build from another VectorType. explicit Builder(VectorType other) - : shape(other.getShape()), elementType(other.getElementType()), + : elementType(other.getElementType()), shape(other.getShape()), scalableDims(other.getScalableDims()) {} /// Build from scratch. 
Builder(ArrayRef shape, Type elementType, - unsigned numScalableDims = 0, ArrayRef scalableDims = {}) - : shape(shape), elementType(elementType) { - if (scalableDims.empty()) - scalableDims = SmallVector(shape.size(), false); - else - this->scalableDims = scalableDims; - } + ArrayRef scalableDims = {}) + : elementType(elementType), shape(shape), scalableDims(scalableDims) {} Builder &setShape(ArrayRef newShape, ArrayRef newIsScalableDim = {}) { - if (newIsScalableDim.empty()) - scalableDims = SmallVector(shape.size(), false); - else - scalableDims = newIsScalableDim; - shape = newShape; + scalableDims = newIsScalableDim; return *this; } @@ -345,25 +329,16 @@ class VectorType::Builder { /// Erase a dim from shape @pos. Builder &dropDim(unsigned pos) { assert(pos < shape.size() && "overflow"); - if (storage.empty()) - storage.append(shape.begin(), shape.end()); - if (storageScalableDims.empty()) - storageScalableDims.append(scalableDims.begin(), scalableDims.end()); - storage.erase(storage.begin() + pos); - storageScalableDims.erase(storageScalableDims.begin() + pos); - shape = {storage.data(), storage.size()}; - scalableDims = - ArrayRef(storageScalableDims.data(), storageScalableDims.size()); + shape.erase(pos); + if (!scalableDims.empty()) + scalableDims.erase(pos); return *this; } /// Set a dim in shape @pos to val. Builder &setDim(unsigned pos, int64_t val) { - if (storage.empty()) - storage.append(shape.begin(), shape.end()); - assert(pos < storage.size() && "overflow"); - storage[pos] = val; - shape = {storage.data(), storage.size()}; + assert(pos < shape.size() && "overflow"); + shape.set(pos, val); return *this; } @@ -372,13 +347,9 @@ class VectorType::Builder { } private: - ArrayRef shape; - // Owning shape data for copy-on-write operations. - SmallVector storage; Type elementType; - ArrayRef scalableDims; - // Owning scalableDims data for copy-on-write operations. 
- SmallVector storageScalableDims; + CopyOnWriteArrayRef shape; + CopyOnWriteArrayRef scalableDims; }; /// Given an `originalShape` and a `reducedShape` assumed to be a subset of diff --git a/mlir/include/mlir/Support/ADTExtras.h b/mlir/include/mlir/Support/ADTExtras.h new file mode 100644 index 0000000000000..1e4708f8f7d3f --- /dev/null +++ b/mlir/include/mlir/Support/ADTExtras.h @@ -0,0 +1,82 @@ +//===- ADTExtras.h - Extra ADTs for use in MLIR -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_ADTEXTRAS_H +#define MLIR_SUPPORT_ADTEXTRAS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// CopyOnWriteArrayRef +//===----------------------------------------------------------------------===// + +// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on +// modification. This is for use in the mlir::<Type>::Builders. +template +class CopyOnWriteArrayRef { +public: + CopyOnWriteArrayRef(ArrayRef array) : nonOwning(array){}; + + CopyOnWriteArrayRef &operator=(ArrayRef array) { + nonOwning = array; + owningStorage = {}; + return *this; + } + + void insert(size_t index, T value) { + SmallVector &vector = ensureCopy(); + vector.insert(vector.begin() + index, value); + } + + void erase(size_t index) { + // Note: A copy can be avoided when just dropping the front/back dims. 
+ if (isNonOwning() && index == 0) { + nonOwning = nonOwning.drop_front(); + } else if (isNonOwning() && index == size() - 1) { + nonOwning = nonOwning.drop_back(); + } else { + SmallVector &vector = ensureCopy(); + vector.erase(vector.begin() + index); + } + } + + void set(size_t index, T value) { ensureCopy()[index] = value; } + + size_t size() const { return ArrayRef(*this).size(); } + + bool empty() const { return ArrayRef(*this).empty(); } + + operator ArrayRef() const { + return nonOwning.empty() ? ArrayRef(owningStorage) : nonOwning; + } + +private: + bool isNonOwning() const { return !nonOwning.empty(); } + + SmallVector &ensureCopy() { + // Empty non-owning storage signals the array has been copied to the owning + // storage (or both are empty). Note: `nonOwning` should never reference + // `owningStorage`. This can lead to dangling references if the + // CopyOnWriteArrayRef is copied. + if (isNonOwning()) { + owningStorage = SmallVector(nonOwning); + nonOwning = {}; + } + return owningStorage; + } + + ArrayRef nonOwning; + SmallVector owningStorage; +}; + +} // namespace mlir + +#endif diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp index 82674fd3768b6..61264bc523648 100644 --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -131,4 +131,99 @@ TEST(ShapedTypeTest, CloneVector) { VectorType::get(vectorNewShape, vectorNewType)); } +TEST(ShapedTypeTest, VectorTypeBuilder) { + MLIRContext context; + Type f32 = FloatType::getF32(&context); + + SmallVector shape{2, 4, 8, 9, 1}; + SmallVector scalableDims{true, false, true, false, false}; + VectorType vectorType = VectorType::get(shape, f32, scalableDims); + + { + // Drop some dims. 
+ VectorType dropFrontTwoDims = + VectorType::Builder(vectorType).dropDim(0).dropDim(0); + ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType()); + ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape()); + ASSERT_EQ(vectorType.getScalableDims().drop_front(2), + dropFrontTwoDims.getScalableDims()); + } + + { + // Set some dims. + VectorType setTwoDims = + VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12); + ASSERT_EQ(setTwoDims.getShape(), ArrayRef({10, 4, 8, 12, 1})); + ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType()); + ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims()); + } + + { + // Test for bug from: + // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a + // Constructs a temporary builder, modifies it, copies it to `builder`. + // This used to lead to a use-after-free. Running under sanitizers will + // catch any issues. + VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16); + VectorType newVectorType = VectorType(builder); + ASSERT_EQ(newVectorType.getDimSize(0), 16); + } + + { + // Make builder from scratch (without scalable dims) -- this used to lead to + // a use-after-free; see: https://github.com/llvm/llvm-project/pull/68969. + // Running under sanitizers will catch any issues. + SmallVector shape{1, 2, 3, 4}; + VectorType::Builder builder(shape, f32); + ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape)); + } + + { + // Set vector shape (without scalable dims) -- this used to lead to + // a use-after-free; see: https://github.com/llvm/llvm-project/pull/68969. + // Running under sanitizers will catch any issues. 
+ VectorType::Builder builder(vectorType); + SmallVector newShape{2, 2}; + builder.setShape(newShape); + ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape)); + } +} + +TEST(ShapedTypeTest, RankedTensorTypeBuilder) { + MLIRContext context; + Type f32 = FloatType::getF32(&context); + + SmallVector shape{2, 4, 8, 16, 32}; + RankedTensorType tensorType = RankedTensorType::get(shape, f32); + + { + // Drop some dims. + RankedTensorType dropFrontTwoDims = + RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0); + ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType()); + ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef({16, 32})); + } + + { + // Insert some dims. + RankedTensorType insertTwoDims = + RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3); + ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType()); + ASSERT_EQ(insertTwoDims.getShape(), + ArrayRef({2, 4, 7, 9, 8, 16, 32})); + } + + { + // Test for bug from: + // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a + // Constructs a temporary builder, modifies it, copies it to `builder`. + // This used to lead to a use-after-free. Running under sanitizers will + // catch any issues. + RankedTensorType::Builder builder = + RankedTensorType::Builder(tensorType).dropDim(0); + RankedTensorType newTensorType = RankedTensorType(builder); + ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape()); + } +} + } // namespace