diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index 2ed66cc83eefb..b0327cc10e9de 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -234,6 +234,60 @@ class SaveResultOpConversion } }; +template +static mlir::LogicalResult +processReturnLikeOp(OpTy ret, mlir::Value newArg, + mlir::PatternRewriter &rewriter) { + auto loc = ret.getLoc(); + rewriter.setInsertionPoint(ret); + mlir::Value resultValue = ret.getOperand(0); + fir::LoadOp resultLoad; + mlir::Value resultStorage; + // Identify the local storage the returned value was loaded from, if any. + if (auto load = resultValue.getDefiningOp()) { + resultLoad = load; + resultStorage = load.getMemref(); + // The result alloca may be behind a fir.declare; look through it. + if (auto declare = resultStorage.getDefiningOp()) + resultStorage = declare.getMemref(); + } + // Replace old local storage with new storage argument, unless + // the derived type is C_PTR/C_FUNPTR, in which case the return + // type is updated to return void* (no new argument is passed). + if (fir::isa_builtin_cptr_type(resultValue.getType())) { + auto module = ret->template getParentOfType(); + FirOpBuilder builder(rewriter, module); + mlir::Value cptr = resultValue; + if (resultLoad) { + // Replace whole derived type load by component load. 
+ cptr = resultLoad.getMemref(); + rewriter.setInsertionPoint(resultLoad); + } + mlir::Value newResultValue = + fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); + newResultValue = builder.createConvert( + loc, getVoidPtrType(ret.getContext()), newResultValue); + rewriter.setInsertionPoint(ret); + rewriter.replaceOpWithNewOp(ret, mlir::ValueRange{newResultValue}); + } else if (resultStorage) { + resultStorage.replaceAllUsesWith(newArg); + rewriter.replaceOpWithNewOp(ret); + } else { + // The result storage may have been optimized out by a memory to + // register pass; this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result + // storage at the return point. + rewriter.create(loc, resultValue, newArg); + rewriter.replaceOpWithNewOp(ret); + } + // Delete the old local result storage if it is no longer used. + if (resultStorage) + if (auto alloc = resultStorage.getDefiningOp()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); + return mlir::success(); +} + class ReturnOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern { llvm::LogicalResult matchAndRewrite(mlir::func::ReturnOp ret, mlir::PatternRewriter &rewriter) const override { - auto loc = ret.getLoc(); - rewriter.setInsertionPoint(ret); - mlir::Value resultValue = ret.getOperand(0); - fir::LoadOp resultLoad; - mlir::Value resultStorage; - // Identify result local storage. - if (auto load = resultValue.getDefiningOp()) { - resultLoad = load; - resultStorage = load.getMemref(); - // The result alloca may be behind a fir.declare, if any. - if (auto declare = resultStorage.getDefiningOp()) - resultStorage = declare.getMemref(); - } - // Replace old local storage with new storage argument, unless - // the derived type is C_PTR/C_FUN_PTR, in which case the return - // type is updated to return void* (no new argument is passed). 
- if (fir::isa_builtin_cptr_type(resultValue.getType())) { - auto module = ret->getParentOfType(); - FirOpBuilder builder(rewriter, module); - mlir::Value cptr = resultValue; - if (resultLoad) { - // Replace whole derived type load by component load. - cptr = resultLoad.getMemref(); - rewriter.setInsertionPoint(resultLoad); - } - mlir::Value newResultValue = - fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); - newResultValue = builder.createConvert( - loc, getVoidPtrType(ret.getContext()), newResultValue); - rewriter.setInsertionPoint(ret); - rewriter.replaceOpWithNewOp( - ret, mlir::ValueRange{newResultValue}); - } else if (resultStorage) { - resultStorage.replaceAllUsesWith(newArg); - rewriter.replaceOpWithNewOp(ret); - } else { - // The result storage may have been optimized out by a memory to - // register pass, this is possible for fir.box results, or fir.record - // with no length parameters. Simply store the result in the result - // storage. at the return point. - rewriter.create(loc, resultValue, newArg); - rewriter.replaceOpWithNewOp(ret); - } - // Delete result old local storage if unused. 
- if (resultStorage) - if (auto alloc = resultStorage.getDefiningOp()) - if (alloc->use_empty()) - rewriter.eraseOp(alloc); - return mlir::success(); + return processReturnLikeOp(ret, newArg, rewriter); + } + +private: + mlir::Value newArg; +}; + +class GPUReturnOpConversion + : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) + : OpRewritePattern(context), newArg{newArg} {} + llvm::LogicalResult + matchAndRewrite(mlir::gpu::ReturnOp ret, + mlir::PatternRewriter &rewriter) const override { + return processReturnLikeOp(ret, newArg, rewriter); } private: @@ -373,6 +395,9 @@ class AbstractResultOpt patterns.insert(context, newArg); target.addDynamicallyLegalOp( [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); + patterns.insert(context, newArg); + target.addDynamicallyLegalOp( + [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); }); assert(func.getFunctionType() == getNewFunctionType(funcTy, shouldBoxResult)); } else { diff --git a/flang/test/Fir/CUDA/cuda-abstract-result.mlir b/flang/test/Fir/CUDA/cuda-abstract-result.mlir new file mode 100644 index 0000000000000..8c59487ca5cd5 --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-abstract-result.mlir @@ -0,0 +1,37 @@ +// RUN: fir-opt -pass-pipeline='builtin.module(gpu.module(gpu.func(abstract-result)))' %s | FileCheck %s + +gpu.module @test { + gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref>, %arg1: !fir.ref) -> !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> { + %c1_i32 = arith.constant 1 : i32 + %18 = fir.dummy_scope : !fir.dscope + %19 = fir.declare %arg0 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Ea"} : (!fir.ref>, !fir.dscope) -> !fir.ref> + %20 = fir.declare %arg1 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Eb"} : (!fir.ref, !fir.dscope) -> !fir.ref + %21 = fir.alloca !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {bindc_name = "c", uniq_name = 
"_QMinterval_mFtest1Ec"} + %22 = fir.declare %21 {uniq_name = "_QMinterval_mFtest1Ec"} : (!fir.ref>) -> !fir.ref> + %23 = fir.alloca i32 {bindc_name = "warpsize", uniq_name = "_QMcudadeviceECwarpsize"} + %24 = fir.declare %23 {uniq_name = "_QMcudadeviceECwarpsize"} : (!fir.ref) -> !fir.ref + %25 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> + %26 = fir.coordinate_of %19, %25 : (!fir.ref>, !fir.field) -> !fir.ref + %27 = fir.load %20 : !fir.ref + %28 = arith.negf %27 fastmath : f32 + %29 = fir.load %26 : !fir.ref + %30 = fir.call @__fadd_rd(%29, %28) proc_attrs fastmath : (f32, f32) -> f32 + %31 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> + %32 = fir.coordinate_of %22, %31 : (!fir.ref>, !fir.field) -> !fir.ref + fir.store %30 to %32 : !fir.ref + %33 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> + %34 = fir.coordinate_of %19, %33 : (!fir.ref>, !fir.field) -> !fir.ref + %35 = fir.load %20 : !fir.ref + %36 = arith.negf %35 fastmath : f32 + %37 = fir.load %34 : !fir.ref + %38 = fir.call @__fadd_ru(%37, %36) proc_attrs fastmath : (f32, f32) -> f32 + %39 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> + %40 = fir.coordinate_of %22, %39 : (!fir.ref>, !fir.field) -> !fir.ref + fir.store %38 to %40 : !fir.ref + %41 = fir.load %22 : !fir.ref> + gpu.return %41 : !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> + } +} + +// CHECK: gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref) { +// CHECK: gpu.return{{$}}