diff --git a/samples/common/sampleOptions.cpp b/samples/common/sampleOptions.cpp index 861267f4..ac809828 100644 --- a/samples/common/sampleOptions.cpp +++ b/samples/common/sampleOptions.cpp @@ -1386,6 +1386,10 @@ void BuildOptions::parse(Arguments& arguments) if (getAndDelOption(arguments, "--saveEngine", engine)) { + if (!canWriteFile(engine)) + { + throw std::invalid_argument(std::string("Cannot write engine file to path: ") + engine); + } save = true; } if (load && save) diff --git a/samples/common/sampleUtils.cpp b/samples/common/sampleUtils.cpp index 93be2845..aa3bfbb0 100644 --- a/samples/common/sampleUtils.cpp +++ b/samples/common/sampleUtils.cpp @@ -105,6 +105,36 @@ void loadFromFile(std::string const& fileName, char* dst, size_t size) } } +// Check if the file at the given path can be written to. +bool canWriteFile(const std::string& path) +{ + // Verify that the target directory exists + namespace fs = std::filesystem; + fs::path p(path); + fs::path dir = p.has_parent_path() ? p.parent_path() : fs::current_path(); + if (!fs::exists(dir) || !fs::is_directory(dir)) + { + return false; + } + + // Try creating and writing to a temporary file in the directory + const fs::path tempFilePath = dir / ".writetest.tmp"; + std::ofstream test(tempFilePath.string(), std::ios::out | std::ios::trunc); + if (!test.is_open()) + { + return false; + } + test << "test"; + const bool ok = test.good(); + test.close(); + + // Clean up the temporary file without throwing on failure + std::error_code ec; + fs::remove(tempFilePath, ec); + + return ok; +} + std::vector splitToStringVec(std::string const& s, char separator, int64_t maxSplit) { std::vector splitted; diff --git a/samples/common/sampleUtils.h b/samples/common/sampleUtils.h index 118a336b..535247ee 100644 --- a/samples/common/sampleUtils.h +++ b/samples/common/sampleUtils.h @@ -19,6 +19,7 @@ #define TRT_SAMPLE_UTILS_H #include +#include #include #include #include @@ -75,6 +76,8 @@ void dumpInt4Buffer(void const* buffer, std::string const& separator, std::ostre void loadFromFile(std::string const& fileName, char* dst, size_t size); +bool canWriteFile(const std::string& path); + std::vector splitToStringVec(std::string const& option, char separator, int64_t maxSplit = -1); bool broadcastIOFormats(std::vector const& formats, size_t nbBindings, bool isInput = true);