Fix - Model Download Race Condition #2305
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -662,6 +662,34 @@ function extractExplicitBackend(loadBody?: Record<string, unknown>): { recipe: s | |
| return null; | ||
| } | ||
|
|
||
| /** | ||
| * Wait for an in-flight model download started by another caller (or tab) | ||
| * instead of starting a second /pull that would race on the server. | ||
| */ | ||
| async function awaitExistingModelDownload(modelName: string): Promise<void> { | ||
| const downloadId = downloadTracker.getStableDownloadId(modelName, 'model'); | ||
| downloadTracker.startServerPolling(); | ||
|
|
||
| const serverDownloads = await downloadTracker.hydrateFromServer(); | ||
| const active = serverDownloads.find( | ||
| item => item.model_name === modelName && | ||
| (item.running === true || item.status === 'downloading'), | ||
| ); | ||
| if (active) { | ||
| downloadTracker.applyServerDownload(active); | ||
| } else if (!downloadTracker.isActive(modelName)) { | ||
| return; | ||
| } | ||
|
|
||
| await waitForServerDownloadTerminal( | ||
| downloadId, | ||
| modelName, | ||
| new AbortController(), | ||
| () => undefined, | ||
| Boolean(active), | ||
| ); | ||
| } | ||
|
|
||
| /** | ||
| * Universal pre-flight check for all inference requests. | ||
| * Ensures backend is installed, model is downloaded, and model is loaded — | ||
|
|
@@ -791,7 +819,12 @@ async function ensureModelReadyInternal( | |
|
|
||
| // Step 5: Pull model if not downloaded (shows in Download Manager) | ||
| if (!isDownloaded) { | ||
| await pullModel(modelName, { declaredSizeGB: modelsData[modelName]?.size }); | ||
| if (downloadTracker.isActive(modelName) || | ||
| await downloadTracker.hasActiveServerDownload(modelName)) { | ||
|
Collaborator
This pre-flight check has the same exact-match limitation as awaitExistingModelDownload(). If server-side dedupe is intended to be “one download per logical model”, the client check should use the same identity semantics as the server lock, not only the raw UI/request name.
||
| await awaitExistingModelDownload(modelName); | ||
| } else { | ||
| await pullModel(modelName, { declaredSizeGB: modelsData[modelName]?.size }); | ||
| } | ||
| } | ||
|
|
||
| // Step 6: Load model into memory (merge loadBody if provided) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3137,13 +3137,34 @@ bool ModelManager::is_model_downloaded(const std::string& model_name) { | |
| return false; | ||
| } | ||
|
|
||
| std::shared_ptr<std::mutex> ModelManager::get_model_download_lock(const std::string& model_name) { | ||
| const std::string lock_key = resolve_model_name(model_name); | ||
| std::lock_guard<std::mutex> guard(model_download_locks_mutex_); | ||
| if (auto it = model_download_locks_.find(lock_key); it != model_download_locks_.end()) { | ||
| if (auto existing = it->second.lock()) { | ||
| return existing; | ||
| } | ||
| } | ||
| auto lock = std::make_shared<std::mutex>(); | ||
| model_download_locks_[lock_key] = lock; | ||
| return lock; | ||
| } | ||
|
|
||
| void ModelManager::download_registered_model(const ModelInfo& info, bool do_not_upgrade, DownloadProgressCallback progress_callback) { | ||
| // Cloud models have no local artifacts; "downloading" is a no-op. | ||
| if (info.recipe == "cloud") { | ||
| update_model_in_cache(info.model_name, true); | ||
| return; | ||
| } | ||
|
|
||
| auto model_lock = get_model_download_lock(info.model_name); | ||
| std::lock_guard<std::mutex> download_guard(*model_lock); | ||
|
|
||
| // Another caller may have finished while we waited for the model lock. | ||
| if (do_not_upgrade && is_model_downloaded(info.model_name)) { | ||
|
Collaborator
This post-lock re-check only catches cache-first callers. /load still calls download_registered_model(info) with the default do_not_upgrade=false, so a /load request queued behind an in-flight /pull can wait on this mutex and then still proceed into the download/update path. Could we either make /load call download_registered_model(info, true) or make this guard explicitly skip when the model became downloaded while waiting? This matters because handle_load() currently downloads missing models with download_registered_model(info) and no do_not_upgrade=true.
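To make the second option concrete, here is a minimal sketch of a guard that skips when the model became downloaded while the caller was waiting on the lock, independent of do_not_upgrade. `ModelDownloadGate` and `should_skip_download` are hypothetical names for illustration, not code from this PR, and the real is_model_downloaded() check is replaced by an in-memory flag.

```cpp
#include <atomic>
#include <cassert>
#include <functional>
#include <mutex>

// Pure decision helper (hypothetical). Skip when the model is present and either
// the caller asked not to upgrade, or the model only appeared while we waited.
inline bool should_skip_download(bool downloaded_before_wait,
                                 bool downloaded_after_wait,
                                 bool do_not_upgrade) {
    if (!downloaded_after_wait) {
        return false;  // nothing on disk yet, so we must download
    }
    return do_not_upgrade || !downloaded_before_wait;
}

// Hypothetical stand-in for the per-model download path.
class ModelDownloadGate {
public:
    // Returns true if fetch() ran, false if the download was skipped.
    bool download(bool do_not_upgrade, const std::function<void()>& fetch) {
        // Snapshot taken BEFORE blocking on the per-model lock.
        const bool downloaded_before_wait = downloaded_.load();
        std::lock_guard<std::mutex> guard(download_mutex_);
        // Re-check AFTER acquiring the lock.
        if (should_skip_download(downloaded_before_wait, downloaded_.load(), do_not_upgrade)) {
            return false;
        }
        fetch();
        downloaded_.store(true);
        return true;
    }

private:
    std::mutex download_mutex_;
    std::atomic<bool> downloaded_{false};
};

// Small self-check of the decision table.
inline void check_skip_rules() {
    assert(should_skip_download(false, true, false));   // /load queued behind /pull
    assert(!should_skip_download(true, true, false));   // explicit upgrade still runs
    assert(should_skip_download(true, true, true));     // cache-first caller
    assert(!should_skip_download(false, false, false)); // still missing
}
```

With this shape, an explicit upgrade of an already-downloaded model still proceeds, while a /load that queued behind a /pull returns without downloading again.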
||
| return; | ||
| } | ||
|
|
||
| // Use recipe-specific download paths | ||
| if (info.recipe == "flm") { | ||
| download_from_flm(info.checkpoint(), do_not_upgrade, progress_callback); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3426,8 +3426,33 @@ void Server::handle_pull(const httplib::Request& req, httplib::Response& res) { | |
| // response exactly as before. | ||
| stream_download_operation(res, operation); | ||
| } else { | ||
| // Legacy synchronous mode - blocks until complete | ||
| model_manager_->download_model(model_name, request_json, do_not_upgrade); | ||
| // Legacy synchronous mode - blocks until complete. Route through the | ||
| // shared download job registry so sync /pull deduplicates with any | ||
| // in-flight server-owned or SSE pull for the same model. | ||
| auto operation = [this, model_name, request_json, do_not_upgrade](DownloadProgressCallback progress_cb) { | ||
| model_manager_->download_model(model_name, request_json, do_not_upgrade, progress_cb); | ||
| }; | ||
| auto job = start_download_job("model:" + model_name, "model", model_name, operation); | ||
|
Collaborator
This job key uses the raw request model_name, while the ModelManager download lock uses resolve_model_name(model_name). Alias vs canonical requests for the same logical model can therefore create separate server job IDs/UI rows even though they serialize lower down. Could we key the download job with the same resolved model name used by get_model_download_lock()?
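A rough sketch of what keying the job by the resolved name could look like. `AliasTable` and `download_job_key` are hypothetical helpers, the model names are made up, and the real resolve_model_name() may behave differently; the only point is that the job key and the lock key come from the same resolved string.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical alias table standing in for resolve_model_name().
class AliasTable {
public:
    void add_alias(const std::string& alias, const std::string& canonical) {
        aliases_[alias] = canonical;
    }

    // Returns the canonical name, or the input unchanged if it is not an alias.
    std::string resolve(const std::string& name) const {
        auto it = aliases_.find(name);
        return it == aliases_.end() ? name : it->second;
    }

private:
    std::unordered_map<std::string, std::string> aliases_;
};

// Build the job key from the resolved name so alias and canonical requests collide.
inline std::string download_job_key(const AliasTable& table, const std::string& requested_name) {
    return "model:" + table.resolve(requested_name);
}

// Self-check with made-up model names.
inline void check_alias_and_canonical_share_a_key() {
    AliasTable table;
    table.add_alias("example-alias", "Example-Model-GGUF");
    assert(download_job_key(table, "example-alias") ==
           download_job_key(table, "Example-Model-GGUF"));
    assert(download_job_key(table, "other-model") == "model:other-model");
}
```

The response body could still echo the name the caller sent, so only the dedupe key would change.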
||
| join_download_job(job); | ||
|
|
||
| std::string error_message; | ||
| std::string error_code; | ||
| { | ||
| std::lock_guard<std::mutex> lock(downloads_mutex_); | ||
| if (job->status == "error") { | ||
| error_message = job->error; | ||
| if (job->progress.contains("code") && job->progress["code"].is_string()) { | ||
| error_code = job->progress["code"].get<std::string>(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (!error_message.empty()) { | ||
| if (error_code == lemon::kUnknownModelErrorCode) { | ||
| throw lemon::UnknownModelError(error_message); | ||
| } | ||
| throw std::runtime_error(error_message); | ||
| } | ||
|
|
||
| nlohmann::json response = {{"status", "success"}, {"model_name", model_name}}; | ||
| res.set_content(response.dump(), "application/json"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,8 @@ | |
| #include <cctype> | ||
| #include <fstream> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
| #include <mbedtls/md.h> | ||
|
|
||
|
|
@@ -768,11 +770,51 @@ DownloadResult HttpClient::download_attempt(const std::string& url, | |
| return result; | ||
| } | ||
|
|
||
| // Serialize concurrent downloads to the same output path. Without this, two | ||
| // callers (e.g. /pull and /load racing, or duplicate /pull requests) can both | ||
| // write to the same .partial file and corrupt the on-disk bytes. | ||
| struct PathDownloadLockRegistry { | ||
| std::mutex registry_mutex; | ||
| std::unordered_map<std::string, std::weak_ptr<std::mutex>> locks; | ||
|
|
||
| std::shared_ptr<std::mutex> acquire(const std::string& output_path) { | ||
| const std::string key = download_lock_key(output_path); | ||
| std::lock_guard<std::mutex> guard(registry_mutex); | ||
| if (auto it = locks.find(key); it != locks.end()) { | ||
| if (auto existing = it->second.lock()) { | ||
| return existing; | ||
| } | ||
| } | ||
| auto lock = std::make_shared<std::mutex>(); | ||
| locks[key] = lock; | ||
| return lock; | ||
| } | ||
|
|
||
| private: | ||
| static std::string download_lock_key(const std::string& output_path) { | ||
| fs::path path = path_from_utf8(output_path); | ||
| std::error_code ec; | ||
| fs::path normalized = fs::weakly_canonical(path, ec); | ||
| if (ec) { | ||
| normalized = fs::absolute(path, ec); | ||
| if (ec) { | ||
| normalized = path; | ||
| } | ||
| } | ||
| return path_to_utf8(normalized); | ||
| } | ||
| }; | ||
|
|
||
| static PathDownloadLockRegistry g_path_download_locks; | ||
|
|
||
| DownloadResult HttpClient::download_file(const std::string& url, | ||
| const std::string& output_path, | ||
| ProgressCallback callback, | ||
| const std::map<std::string, std::string>& headers, | ||
| const DownloadOptions& options) { | ||
| auto path_lock = g_path_download_locks.acquire(output_path); | ||
| std::lock_guard<std::mutex> path_guard(*path_lock); | ||
|
Collaborator
This path-level lock is a good last line of defense against .partial corruption. Given that the higher-level model/job dedupe can still miss alias/canonical cases, could we add a regression test that races two callers against the same output path and verifies that the partial file is not concurrently written/corrupted?
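One possible shape for such a regression test, as a sketch only. It assumes a simplified copy of the registry (`PathLocks`, without path normalization) and uses an in-memory string as a stand-in for the .partial file; the path string and all names here are hypothetical. A real test would presumably go through HttpClient::download_file against a local server.

```cpp
#include <atomic>
#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// Simplified stand-in for PathDownloadLockRegistry (no path normalization).
class PathLocks {
public:
    std::shared_ptr<std::mutex> acquire(const std::string& key) {
        std::lock_guard<std::mutex> guard(registry_mutex_);
        if (auto it = locks_.find(key); it != locks_.end()) {
            if (auto existing = it->second.lock()) {
                return existing;  // another caller still holds this path's mutex
            }
        }
        auto lock = std::make_shared<std::mutex>();
        locks_[key] = lock;
        return lock;
    }

private:
    std::mutex registry_mutex_;
    std::unordered_map<std::string, std::weak_ptr<std::mutex>> locks_;
};

struct RaceReport {
    int max_concurrent_writers = 0;   // must stay at 1 if the lock works
    bool partial_is_uniform = false;  // true if bytes come from a single writer
};

// Races `writers` threads (at most 26) against one path; each fills the buffer with its own letter.
inline RaceReport race_writers_on_same_path(int writers, int bytes_per_writer) {
    PathLocks locks;
    std::string partial;  // in-memory stand-in for the shared .partial file
    std::atomic<int> active{0};
    std::atomic<int> max_active{0};

    auto writer = [&](char fill) {
        auto path_lock = locks.acquire("/models/example.gguf.partial");
        std::lock_guard<std::mutex> guard(*path_lock);
        const int now = ++active;
        int seen = max_active.load();
        while (now > seen && !max_active.compare_exchange_weak(seen, now)) {
        }
        partial.clear();
        for (int i = 0; i < bytes_per_writer; ++i) {
            partial.push_back(fill);
            if (i % 64 == 0) {
                std::this_thread::yield();  // widen the window for interleaving
            }
        }
        --active;
    };

    std::vector<std::thread> threads;
    for (int i = 0; i < writers; ++i) {
        threads.emplace_back(writer, static_cast<char>('A' + i));
    }
    for (auto& t : threads) {
        t.join();
    }

    RaceReport report;
    report.max_concurrent_writers = max_active.load();
    report.partial_is_uniform =
        partial.size() == static_cast<std::string::size_type>(bytes_per_writer) &&
        partial.find_first_not_of(partial[0]) == std::string::npos;
    return report;
}

// The property the regression test would assert.
inline void check_no_concurrent_partial_writes() {
    const RaceReport report = race_writers_on_same_path(4, 4096);
    assert(report.max_concurrent_writers == 1);
    assert(report.partial_is_uniform);
}
```

Removing the `std::lock_guard` line in `writer` should make the check fail intermittently, which is a quick way to confirm the test actually exercises the race.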
||
|
|
||
| DownloadResult final_result; | ||
| int retry_delay_ms = options.initial_retry_delay_ms; | ||
| const ExpectedHash expected_hash = parse_expected_hash(options); | ||
|
|
||
Same alias/canonicalization issue on the client side: this only detects an existing server download if item.model_name exactly equals the caller’s modelName. If another caller started the same model through a different alias, ensureModelReady() can miss the active job and start another /pull. Could we normalize model names here or have the server snapshot expose a canonical model ID to compare against?