//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// TODO: reuse StorageLayout::foreachField?

// TODO: we need COO AoS and SoA

// Convert type range to new types range, with sparse tensors externalized.
void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
               SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    // Convert the external representation of the values array.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    auto shape = {ShapedType::kDynamic};
    auto vtp = RankedTensorType::get(shape, stt.getElementType());
    convTypes.push_back(vtp);
    if (extraTypes)
      extraTypes->push_back(vtp);
    // Convert the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        auto ptp = RankedTensorType::get(shape, stt.getPosType());
        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
        convTypes.push_back(ptp);
        convTypes.push_back(ctp);
        if (extraTypes) {
          extraTypes->push_back(ptp);
          extraTypes->push_back(ctp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
  }
}
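
// Example (an illustrative sketch, not part of the original code): for a 2-d
// CSR tensor with f64 elements and default (index) position/coordinate widths,
//
//   tensor<10x10xf64, #CSR>
//
// the conversion yields the three "external" types
//
//   tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
//
// (values, positions, coordinates); the dense level contributes nothing extra.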

// Convert input and output values to [dis]assemble ops for sparse tensors.
void convVals(OpBuilder &builder, Location loc, TypeRange types,
              ValueRange fromVals, ValueRange extraVals,
              SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Convert the external representation of the values array.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    auto shape = {ShapedType::kDynamic};
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    // Collect the external representation of the values array for
    // input or the outgoing sparse tensor for output.
    inputs.push_back(fromVals[idx++]);
    if (!isIn) {
      inputs.push_back(extraVals[extra++]);
      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
      cntTypes.push_back(builder.getIndexType());
    }
    // Collect the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
          inputs.push_back(fromVals[idx++]);
        } else {
          Type pTp = stt.getPosType();
          Type cTp = stt.getCrdType();
          inputs.push_back(extraVals[extra++]);
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(RankedTensorType::get(shape, pTp));
          retTypes.push_back(RankedTensorType::get(shape, cTp));
          cntTypes.push_back(pTp);
          cntTypes.push_back(cTp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}
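
// Example (an illustrative sketch, not part of the original code): when a CSR
// result with f64 elements and default index widths is disassembled above, the
// emitted disassemble op returns the data outputs
//
//   tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
//
// followed by one counter per output (here index, index, index); only the data
// outputs are forwarded to `toVals`, the trailing counters are dropped.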

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper functions that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// adds the following structure in a wrapper
//
// void spiface_foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// adds the following structure in a wrapper
//
// ... T1..TN ... spiface_bar(..., t1'..tn') {
//   ..., t, ... = bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external frameworks
//
// TODO: use "inlining" instead of a wrapper?
//
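// Example (an illustrative sketch only; the types and names below are
// assumptions, not generated verbatim): for a public entry point that takes a
// CSR matrix with f64 elements and default (index) bit widths,
//
//   func.func @foo(%t: tensor<10x10xf64, #CSR>) { ... }
//
// the wrapper roughly becomes
//
//   func.func @spiface_foo(%v: tensor<?xf64>,
//                          %p: tensor<?xindex>,
//                          %c: tensor<?xindex>) {
//     %t = sparse_tensor.assemble ...   // rebuild tensor<10x10xf64, #CSR>
//     call @foo(%t) : (tensor<10x10xf64, #CSR>) -> ()
//     return
//   }
//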
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite entry methods for which the c-interface was requested.
    if (!funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName()))
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper function.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Start the new wrapper function. Together with the c-interface mangling,
    // a sparse external entry point eventually will have a name like:
    //   _mlir_ciface_spiface_XXX(...)
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();
    func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                  UnitAttr::get(context));

    // Construct new wrapper function body.
    auto org = SymbolRefAttr::get(context, funcOp.getName());
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call original function.
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(),
                                              org, inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Strip the c-interface attribute from the original function.
    funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}
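
// A minimal usage sketch (the surrounding pass and `ctx` are hypothetical and
// not defined in this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseAssembler(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));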