diff --git a/src/libtorch_c/torch_c.cpp b/src/libtorch_c/torch_c.cpp index 591d10b45..5ea4faf80 100644 --- a/src/libtorch_c/torch_c.cpp +++ b/src/libtorch_c/torch_c.cpp @@ -218,10 +218,12 @@ void torchRunModule(ModuleContext* ctx, const char* fnName, } if (ctx->module) { + torch::NoGradGuard guard; torch::jit::script::Method method = ctx->module->get_method(fnName); method.run(stack); } else { + torch::NoGradGuard guard; torch::jit::Function& fn = ctx->cu->get_function(fnName); fn.run(stack); }