Skip to content

Commit 0201e99

Browse files
committed
common : use cpp-httplib as a cURL alternative for downloads
The existing cURL implementation is intentionally left untouched to prevent any regressions and to allow for safe, side-by-side testing by toggling the `LLAMA_CURL` CMake option. Signed-off-by: Adrien Gallouët <[email protected]>
1 parent 60dbbce commit 0201e99

File tree

2 files changed

+366
-8
lines changed

2 files changed

+366
-8
lines changed

common/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,16 @@ if (LLAMA_CURL)
8787
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
8888
include_directories(${CURL_INCLUDE_DIRS})
8989
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
90-
endif ()
90+
else()
91+
find_package(OpenSSL)
92+
if (OpenSSL_FOUND)
93+
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
94+
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
95+
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
96+
else()
97+
message(STATUS "OpenSSL not found, SSL support disabled")
98+
endif()
99+
endif()
91100

92101
if (LLAMA_LLGUIDANCE)
93102
include(ExternalProject)

common/arg.cpp

Lines changed: 356 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
#if defined(LLAMA_USE_CURL)
3838
#include <curl/curl.h>
3939
#include <curl/easy.h>
40+
#else
41+
#include <cpp-httplib/httplib.h>
4042
#endif
4143

4244
#ifdef __linux__
@@ -572,17 +574,364 @@ bool common_has_curl() {
572574
return false;
573575
}
574576

575-
static bool common_download_file_single_online(const std::string &, const std::string &, const std::string &) {
576-
LOG_ERR("error: built without CURL, cannot download model from internet\n");
577-
return false;
577+
struct common_url {
578+
std::string scheme;
579+
std::string user;
580+
std::string password;
581+
std::string host;
582+
std::string path;
583+
};
584+
585+
static common_url parse_url(const std::string & url) {
586+
common_url parts;
587+
auto scheme_end = url.find("://");
588+
589+
if (scheme_end == std::string::npos) {
590+
throw std::runtime_error("invalid URL: no scheme");
591+
}
592+
parts.scheme = url.substr(0, scheme_end);
593+
594+
if (parts.scheme != "http" && parts.scheme != "https") {
595+
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
596+
}
597+
598+
auto rest = url.substr(scheme_end + 3);
599+
auto at_pos = rest.find('@');
600+
601+
if (at_pos != std::string::npos) {
602+
auto auth = rest.substr(0, at_pos);
603+
auto colon_pos = auth.find(':');
604+
if (colon_pos != std::string::npos) {
605+
parts.user = auth.substr(0, colon_pos);
606+
parts.password = auth.substr(colon_pos + 1);
607+
} else {
608+
parts.user = auth;
609+
}
610+
rest = rest.substr(at_pos + 1);
611+
}
612+
613+
auto slash_pos = rest.find('/');
614+
615+
if (slash_pos != std::string::npos) {
616+
parts.host = rest.substr(0, slash_pos);
617+
parts.path = rest.substr(slash_pos);
618+
} else {
619+
parts.host = rest;
620+
parts.path = "/";
621+
}
622+
return parts;
623+
}
624+
625+
static std::pair<httplib::Client, common_url> http_client(const std::string & url) {
626+
common_url parts = parse_url(url);
627+
628+
if (parts.host.empty()) {
629+
throw std::runtime_error("error: invalid URL format");
630+
}
631+
632+
if (!parts.user.empty()) {
633+
throw std::runtime_error("error: user:password@ not supported yet"); // TODO
634+
}
635+
636+
httplib::Client cli(parts.scheme + "://" + parts.host);
637+
cli.set_follow_location(true);
638+
639+
// TODO cert
640+
641+
return { std::move(cli), std::move(parts) };
642+
}
643+
644+
static std::string show_masked_url(const common_url & parts) {
645+
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
646+
}
647+
648+
static void print_progress(size_t current, size_t total) { // TODO isatty
649+
if (!total) {
650+
return;
651+
}
652+
653+
size_t width = 50;
654+
size_t pct = (100 * current) / total;
655+
size_t pos = (width * current) / total;
656+
657+
std::cout << "["
658+
<< std::string(pos, '=')
659+
<< (pos < width ? ">" : "")
660+
<< std::string(width - pos, ' ')
661+
<< "] " << std::setw(3) << pct << "% ("
662+
<< current / (1024 * 1024) << " MB / "
663+
<< total / (1024 * 1024) << " MB)\r";
664+
std::cout.flush();
665+
}
666+
667+
struct common_file_metadata {
668+
std::string etag;
669+
std::string last_modified;
670+
};
671+
672+
static std::optional<common_file_metadata> read_metadata(const std::string & path) {
673+
if (!std::filesystem::exists(path)) {
674+
return std::nullopt;
675+
}
676+
677+
nlohmann::json metadata_json;
678+
common_file_metadata metadata;
679+
680+
std::ifstream metadata_in(path);
681+
try {
682+
metadata_in >> metadata_json;
683+
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(),
684+
metadata_json.dump().c_str());
685+
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
686+
metadata.etag = metadata_json.at("etag");
687+
}
688+
if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) {
689+
metadata.last_modified = metadata_json.at("lastModified");
690+
}
691+
} catch (const nlohmann::json::exception & e) {
692+
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what());
693+
return std::nullopt;
694+
}
695+
696+
return metadata;
697+
}
698+
699+
static void write_metadata(const std::string & path,
700+
const std::string & url,
701+
const common_file_metadata & metadata) {
702+
nlohmann::json metadata_json = {
703+
{ "url", url },
704+
{ "etag", metadata.etag },
705+
{ "lastModified", metadata.last_modified }
706+
};
707+
708+
write_file(path, metadata_json.dump(4));
709+
LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str());
710+
}
711+
712+
static bool common_pull_file(httplib::Client & cli,
713+
const std::string & resolve_path,
714+
const std::string & path_tmp,
715+
bool supports_ranges,
716+
size_t existing_size,
717+
size_t & total_size) {
718+
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
719+
if (!ofs.is_open()) {
720+
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
721+
return false;
722+
}
723+
724+
httplib::Headers headers;
725+
if (supports_ranges && existing_size > 0) {
726+
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
727+
}
728+
729+
std::atomic<size_t> downloaded{existing_size};
730+
731+
auto res = cli.Get(resolve_path, headers,
732+
[&](const httplib::Response &response) {
733+
if (existing_size > 0 && response.status != 206) {
734+
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
735+
return false;
736+
}
737+
if (existing_size == 0 && response.status != 200) {
738+
LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
739+
return false;
740+
}
741+
if (total_size == 0 && response.has_header("Content-Length")) {
742+
try {
743+
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
744+
total_size = existing_size + content_length;
745+
} catch (const std::exception &e) {
746+
LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
747+
}
748+
}
749+
return true;
750+
},
751+
[&](const char *data, size_t len) {
752+
ofs.write(data, len);
753+
if (!ofs) {
754+
LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
755+
return false;
756+
}
757+
downloaded += len;
758+
print_progress(downloaded, total_size);
759+
return true;
760+
},
761+
nullptr
762+
);
763+
764+
std::cout << "\n";
765+
766+
if (!res) {
767+
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
768+
return false;
769+
}
770+
771+
return true;
578772
}
579773

580-
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params &) {
581-
if (!url.empty()) {
582-
throw std::runtime_error("error: built without CURL, cannot download model from the internet");
774+
// download one single file from remote URL to local path
775+
static bool common_download_file_single_online(const std::string & url,
776+
const std::string & path,
777+
const std::string & bearer_token) {
778+
// If the file exists, check its JSON metadata companion file.
779+
std::string metadata_path = path + ".json";
780+
static const int max_attempts = 3;
781+
static const int retry_delay_seconds = 2;
782+
783+
auto [cli, parts] = http_client(url);
784+
785+
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
786+
if (!bearer_token.empty()) {
787+
default_headers.insert({"Authorization", "Bearer " + bearer_token});
788+
}
789+
cli.set_default_headers(default_headers);
790+
791+
common_file_metadata last;
792+
const bool file_exists = std::filesystem::exists(path);
793+
if (file_exists) {
794+
if (auto opt = read_metadata(metadata_path)) {
795+
last = *opt;
796+
}
797+
} else {
798+
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
799+
}
800+
801+
for (int i = 0; i < max_attempts; ++i) {
802+
auto head = cli.Head(parts.path);
803+
bool head_ok = head && head->status >= 200 && head->status < 300;
804+
if (!head_ok) {
805+
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
806+
if (file_exists) {
807+
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
808+
return true;
809+
}
810+
}
811+
812+
common_file_metadata current;
813+
if (head_ok) {
814+
if (head->has_header("ETag")) {
815+
current.etag = head->get_header_value("ETag");
816+
}
817+
if (head->has_header("Last-Modified")) {
818+
current.last_modified = head->get_header_value("Last-Modified");
819+
}
820+
}
821+
822+
size_t total_size = 0;
823+
if (head_ok && head->has_header("Content-Length")) {
824+
try {
825+
total_size = std::stoull(head->get_header_value("Content-Length"));
826+
} catch (const std::exception& e) {
827+
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
828+
}
829+
}
830+
831+
bool supports_ranges = false;
832+
if (head_ok && head->has_header("Accept-Ranges")) {
833+
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
834+
}
835+
836+
bool should_download_from_scratch = false;
837+
if (head_ok) {
838+
if (!last.etag.empty() && last.etag != current.etag) {
839+
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
840+
last.etag.c_str(), current.etag.c_str());
841+
should_download_from_scratch = true;
842+
} else if (!last.last_modified.empty() && last.last_modified != current.last_modified) {
843+
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
844+
last.last_modified.c_str(), current.last_modified.c_str());
845+
should_download_from_scratch = true;
846+
}
847+
}
848+
849+
if (file_exists) {
850+
if (!should_download_from_scratch) {
851+
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
852+
return true;
853+
}
854+
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
855+
if (remove(path.c_str()) != 0) {
856+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
857+
return false;
858+
}
859+
}
860+
861+
const std::string path_temporary = path + ".downloadInProgress";
862+
size_t existing_size = 0;
863+
864+
if (std::filesystem::exists(path_temporary)) {
865+
if (supports_ranges && !should_download_from_scratch) {
866+
existing_size = std::filesystem::file_size(path_temporary);
867+
} else if (remove(path_temporary.c_str()) != 0) {
868+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
869+
return false;
870+
}
871+
}
872+
873+
// start the download
874+
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
875+
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(),
876+
current.etag.c_str(), current.last_modified.c_str());
877+
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
878+
if (!was_pull_successful) {
879+
if (i + 1 < max_attempts) {
880+
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
881+
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
882+
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
883+
} else {
884+
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
885+
}
886+
887+
continue;
888+
}
889+
890+
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
891+
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
892+
return false;
893+
}
894+
write_metadata(metadata_path, url, current);
895+
break;
896+
}
897+
898+
return true;
899+
}
900+
901+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
902+
const common_remote_params & params) {
903+
auto [cli, parts] = http_client(url);
904+
905+
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
906+
for (const auto & header : params.headers) {
907+
size_t pos = header.find(':');
908+
if (pos != std::string::npos) {
909+
headers.emplace(header.substr(0, pos), header.substr(pos + 1));
910+
} else {
911+
headers.emplace(header, "");
912+
}
913+
}
914+
915+
if (params.timeout > 0) {
916+
cli.set_read_timeout(params.timeout, 0);
917+
cli.set_write_timeout(params.timeout, 0);
918+
}
919+
920+
std::vector<char> buf;
921+
auto res = cli.Get(parts.path, headers,
922+
[&](const char *data, size_t len) {
923+
buf.insert(buf.end(), data, data + len);
924+
return params.max_size == 0 ||
925+
buf.size() <= static_cast<size_t>(params.max_size);
926+
},
927+
nullptr
928+
);
929+
930+
if (!res) {
931+
throw std::runtime_error("error: cannot make GET request");
583932
}
584933

585-
return {};
934+
return { res->status, std::move(buf) };
586935
}
587936

588937
#endif // LLAMA_USE_CURL

0 commit comments

Comments
 (0)