Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,660 changes: 1,458 additions & 202 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ homepage = "https://github.com/huggingface/text-embeddings-inference"
[workspace.dependencies]
anyhow = "1.0.75"
clap = { version = "4.1", features = ["derive", "env"] }
hf-hub = { version = "0.4", features = ["tokio"], default-features = false }
hf-hub = { version = "1.0.0-rc.1", features = ["rustls-tls"], default-features = false }
metrics = "0.23.1"
nohash-hasher = "0.2"
num_cpus = "1.16.0"
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ memmap2 = "^0.9"
[dev-dependencies]
insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] }
is_close = "0.1.3"
hf-hub = { workspace = true, features = ["ureq"] }
hf-hub = { workspace = true, features = ["blocking"] }
anyhow = { workspace = true }
tokenizers = { workspace = true }
serial_test = { workspace = true }
Expand Down
109 changes: 76 additions & 33 deletions backends/candle/tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use anyhow::Result;
use hf_hub::api::sync::{ApiBuilder, ApiError, ApiRepo};
use hf_hub::{Repo, RepoType};
use anyhow::{anyhow, Result};
use hf_hub::{HFClient, HFError, HFRepositorySync, RepoTypeModel};
use insta::internals::YamlMatcher;
use serde::{Deserialize, Serialize};
use std::cmp::max;
Expand Down Expand Up @@ -132,41 +131,52 @@ pub fn download_artifacts(
revision: Option<&'static str>,
dense_path: Option<&'static str>,
) -> Result<(PathBuf, Option<Vec<String>>)> {
let mut builder = ApiBuilder::from_env().with_progress(false);
let mut builder = HFClient::builder();

if let Ok(token) = std::env::var("HF_TOKEN") {
builder = builder.with_token(Some(token));
builder = builder.token(token);
}

if let Some(cache_dir) = std::env::var_os("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
builder = builder.cache_dir(PathBuf::from(cache_dir));
}

let api = builder.build().unwrap();
let api_repo = if let Some(revision) = revision {
api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
))
} else {
api.repo(Repo::new(model_id.to_string(), RepoType::Model))
};

api_repo.get("config.json")?;
api_repo.get("tokenizer.json")?;

let model_files = match download_safetensors(&api_repo) {
let client = builder.build_sync()?;
let (owner, name) = model_id
.split_once('/')
.ok_or_else(|| anyhow!("model_id must be in `owner/name` form, got `{model_id}`"))?;
let repo = client.model(owner, name);
let revision = revision.map(str::to_string);

repo.download_file()
.filename("config.json")
.maybe_revision(revision.clone())
.send()?;
repo.download_file()
.filename("tokenizer.json")
.maybe_revision(revision.clone())
.send()?;

let model_files = match download_safetensors(&repo, revision.clone()) {
Ok(p) => p,
Err(_) => {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
let p = api_repo.get("pytorch_model.bin")?;
let p = repo
.download_file()
.filename("pytorch_model.bin")
.maybe_revision(revision.clone())
.send()?;
vec![p]
}
};

let dense_paths = if let Ok(modules_path) = api_repo.get("modules.json") {
let dense_paths = if let Ok(modules_path) = repo
.download_file()
.filename("modules.json")
.maybe_revision(revision.clone())
.send()
{
match parse_dense_paths_from_modules(&modules_path) {
Ok(paths) => match paths.len() {
0 => None,
Expand All @@ -177,12 +187,12 @@ pub fn download_artifacts(
paths[0].clone()
};

download_dense_module(&api_repo, &path)?;
download_dense_module(&repo, revision.clone(), &path)?;
Some(vec![path])
}
_ => {
for path in &paths {
download_dense_module(&api_repo, path)?;
download_dense_module(&repo, revision.clone(), path)?;
}
Some(paths)
}
Expand All @@ -197,18 +207,30 @@ pub fn download_artifacts(
Ok((model_root, dense_paths))
}

fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
fn download_safetensors(
repo: &HFRepositorySync<RepoTypeModel>,
revision: Option<String>,
) -> Result<Vec<PathBuf>, HFError> {
// Single file
tracing::info!("Downloading `model.safetensors`");
match api.get("model.safetensors") {
match repo
.download_file()
.filename("model.safetensors")
.maybe_revision(revision.clone())
.send()
{
Ok(p) => return Ok(vec![p]),
Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err),
};

// Sharded weights
// Download and parse index file
tracing::info!("Downloading `model.safetensors.index.json`");
let index_file = api.get("model.safetensors.index.json")?;
let index_file = repo
.download_file()
.filename("model.safetensors.index.json")
.maybe_revision(revision.clone())
.send()?;
let index_file_string: String =
std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted");
let json: serde_json::Value = serde_json::from_str(&index_file_string)
Expand All @@ -230,7 +252,12 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let mut safetensors_files = Vec::new();
for n in safetensors_filenames {
tracing::info!("Downloading `{}`", n);
safetensors_files.push(api.get(&n)?);
safetensors_files.push(
repo.download_file()
.filename(&n)
.maybe_revision(revision.clone())
.send()?,
);
}

Ok(safetensors_files)
Expand All @@ -248,20 +275,36 @@ fn parse_dense_paths_from_modules(modules_path: &PathBuf) -> Result<Vec<String>,
.collect::<Vec<String>>())
}

fn download_dense_module(api: &ApiRepo, dense_path: &str) -> Result<PathBuf, ApiError> {
fn download_dense_module(
repo: &HFRepositorySync<RepoTypeModel>,
revision: Option<String>,
dense_path: &str,
) -> Result<PathBuf, HFError> {
let config_file = format!("{}/config.json", dense_path);
tracing::info!("Downloading `{}`", config_file);
let config_path = api.get(&config_file)?;
let config_path = repo
.download_file()
.filename(&config_file)
.maybe_revision(revision.clone())
.send()?;

let safetensors_file = format!("{}/model.safetensors", dense_path);
tracing::info!("Downloading `{}`", safetensors_file);
match api.get(&safetensors_file) {
match repo
.download_file()
.filename(&safetensors_file)
.maybe_revision(revision.clone())
.send()
{
Ok(_) => {}
Err(err) => {
tracing::warn!("Could not download `{}`: {}", safetensors_file, err);
let pytorch_file = format!("{}/pytorch_model.bin", dense_path);
tracing::info!("Downloading `{}`", pytorch_file);
api.get(&pytorch_file)?;
repo.download_file()
.filename(&pytorch_file)
.maybe_revision(revision.clone())
.send()?;
}
}

Expand Down
Loading
Loading