@@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
1189
1189
// NOTE: this span is a diff rendering of NewRewriter::matchAndRewrite.
// Lines starting with "- " are the removed version, lines starting with "+ "
// are the added version, and bare numbers are old/new source line numbers.
//
// Rewrites a NewOp into: a NewOp producing an ordered COO tensor, a ConvertOp
// to the destination type, and a DeallocTensorOp for the temporary COO tensor.
// Returns failure() when the result type has no encoding or when
// getCOOStart(enc) is 0; otherwise replaces `op` and returns success().
LogicalResult matchAndRewrite (NewOp op,
1190
1190
PatternRewriter &rewriter) const override {
1191
1191
Location loc = op.getLoc ();
1192
- const auto dstTp = getSparseTensorType (op.getResult ());
1193
- const auto encDst = dstTp .getEncoding ();
1194
- if (!dstTp .hasEncoding () || getCOOStart (encDst ) == 0 )
// Guard (new version): bail out if the result carries no encoding or if
// getCOOStart(enc) == 0.
// NOTE(review): presumably a COO start of 0 means the type is already COO
// from the outermost level, so no conversion is needed -- confirm.
1192
+ auto stt = getSparseTensorType (op.getResult ());
1193
+ auto enc = stt .getEncoding ();
1194
+ if (!stt .hasEncoding () || getCOOStart (enc ) == 0 )
1195
1195
return failure ();
1196
1196
1197
1197
// Implement the NewOp as follows:
1198
1198
// %orderedCoo = sparse_tensor.new %filename
1199
1199
// %t = sparse_tensor.convert %orderedCoo
1200
+ // with enveloping reinterpreted_map ops for non-permutations.
// `dstTp` starts as the ranked tensor type of the result; it is reassigned
// below when the encoding is not a permutation.
1201
+ RankedTensorType dstTp = stt.getRankedTensorType ();
1200
1202
// Build the ordered COO type from `dstTp` and read the source into it with a
// fresh NewOp.
RankedTensorType cooTp = getCOOType (dstTp, /* ordered=*/ true );
1201
1203
Value cooTensor = rewriter.create <NewOp>(loc, cooTp, op.getSource ());
1202
- Value convert = rewriter.replaceOpWithNewOp <ConvertOp>(
1203
- op, dstTp.getRankedTensorType (), cooTensor);
// `convert` tracks the current value through the (optional) demap, the
// convert, and the (optional) remap steps.
1204
+ Value convert = cooTensor;
// Non-permutation case: reinterpret the COO value with an encoding that has
// its dimToLvl map removed, and likewise drop the map from the destination
// type, so the ConvertOp below operates on map-free encodings.
1205
+ if (!stt.isPermutation ()) { // demap coo, demap dstTp
1206
+ auto coo = getSparseTensorType (cooTensor).getEncoding ().withoutDimToLvl ();
1207
+ convert = rewriter.create <ReinterpretMapOp>(loc, coo, convert);
1208
+ dstTp = getSparseTensorType (convert).withEncoding (enc.withoutDimToLvl ());
1209
+ }
1210
+ convert = rewriter.create <ConvertOp>(loc, dstTp, convert);
// Non-permutation case: reinterpret the converted value back to the
// original encoding `enc`.
1211
+ if (!stt.isPermutation ()) // remap to original enc
1212
+ convert = rewriter.create <ReinterpretMapOp>(loc, enc, convert);
// Replace the original op with the final value of the chain.
1213
+ rewriter.replaceOp (op, convert);
1204
1214
1205
- // Release the ordered COO tensor.
1215
+ // Release the temporary ordered COO tensor.
1206
1216
// The dealloc is inserted after the op that defines `convert`.
rewriter.setInsertionPointAfterValue (convert);
1207
1217
rewriter.create <DeallocTensorOp>(loc, cooTensor);
1208
1218
1209
1219
return success ();
1210
1220
}
1211
1221
};
1212
1222
1223
+ /// Sparse rewriting rule for the out operator.
1213
1224
struct OutRewriter : public OpRewritePattern <OutOp> {
1214
1225
using OpRewritePattern::OpRewritePattern;
1215
1226
LogicalResult matchAndRewrite (OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
1250
1261
primaryTypeFunctionSuffix (eltTp)};
1251
1262
Value value = genAllocaScalar (rewriter, loc, eltTp);
1252
1263
ModuleOp module = op->getParentOfType <ModuleOp>();
1264
+
1253
1265
// For each element in the source tensor, output the element.
1254
1266
rewriter.create <ForeachOp>(
1255
1267
loc, src, std::nullopt,
0 commit comments