From a572a52d59f610a02e8208702f9a87cb17882f9f Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 3 Jun 2026 14:56:31 +0200 Subject: [PATCH 1/2] Pass weights_only=True to load_state_dict_from_url --- dinov2/hub/backbones.py | 4 ++-- dinov2/hub/cell_dino/backbones.py | 6 ++++-- dinov2/hub/classifiers.py | 4 ++-- dinov2/hub/depthers.py | 6 +++--- dinov2/hub/dinotxt.py | 14 +++++--------- dinov2/hub/utils.py | 8 ++++++++ dinov2/utils/utils.py | 6 +++++- 7 files changed, 29 insertions(+), 19 deletions(-) diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index 9204dc629..f57d6bbc4 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -10,7 +10,7 @@ import torch -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url class Weights(Enum): @@ -73,7 +73,7 @@ def _make_dinov2_model( url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" else: url = convert_path_or_url_to_url(weights) - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + state_dict = _safe_load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) model.load_state_dict(state_dict, strict=True) return model diff --git a/dinov2/hub/cell_dino/backbones.py b/dinov2/hub/cell_dino/backbones.py index 0e6b90b02..ce476f822 100644 --- a/dinov2/hub/cell_dino/backbones.py +++ b/dinov2/hub/cell_dino/backbones.py @@ -8,6 +8,8 @@ import torch +from ..utils import _safe_load_state_dict_from_url + class Weights(Enum): CELL_DINO = "CELL-DINO" @@ -57,8 +59,8 @@ def _make_cell_dino_model( if pretrained_path is not None: state_dict = torch.load(pretrained_path, map_location="cpu") else: - pretrained_url is not None - state_dict = torch.hub.load_state_dict_from_url(pretrained_url, map_location="cpu") + assert pretrained_url is not None + state_dict = _safe_load_state_dict_from_url(pretrained_url, map_location="cpu") model.load_state_dict(state_dict, strict=True) return model diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py index 3f0841efa..1ac6b1c0e 100644 --- a/dinov2/hub/classifiers.py +++ b/dinov2/hub/classifiers.py @@ -10,7 +10,7 @@ import torch.nn as nn from .backbones import _make_dinov2_model -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url class Weights(Enum): @@ -43,7 +43,7 @@ def _make_dinov2_linear_classification_head( model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) layers_str = str(layers) if layers == 4 else "" url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + state_dict = _safe_load_state_dict_from_url(url, map_location="cpu") linear_head.load_state_dict(state_dict, strict=True) return linear_head diff --git a/dinov2/hub/depthers.py b/dinov2/hub/depthers.py index f88b7e9a4..fabbbb23e 100644 --- a/dinov2/hub/depthers.py +++ b/dinov2/hub/depthers.py @@ -11,7 +11,7 @@ from .backbones import _make_dinov2_model from .depth import BNHead, DepthEncoderDecoder, DPTHead -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url, CenterPadding class Weights(Enum): @@ -132,7 +132,7 @@ def _make_dinov2_linear_depther( layers_str = str(layers) if layers == 4 else "" weights_str = weights.value.lower() url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" - checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + checkpoint = _safe_load_state_dict_from_url(url, map_location="cpu") if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] model.load_state_dict(state_dict, strict=False) @@ -220,7 +220,7 @@ def _make_dinov2_dpt_depther( if pretrained: weights_str = weights.value.lower() url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" - checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + checkpoint = _safe_load_state_dict_from_url(url, map_location="cpu") if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] model.load_state_dict(state_dict, strict=False) diff --git a/dinov2/hub/dinotxt.py b/dinov2/hub/dinotxt.py index 3578538ce..d4d892617 100644 --- a/dinov2/hub/dinotxt.py +++ b/dinov2/hub/dinotxt.py @@ -7,7 +7,7 @@ import math from .backbones import dinov2_vitl14_reg -from .utils import _DINOV2_BASE_URL +from .utils import _DINOV2_BASE_URL, _safe_load_state_dict_from_url def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l(): @@ -54,14 +54,10 @@ def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l(): model.visual_model.backbone = vision_backbone model.eval() - visual_model_head_state_dict = torch.hub.load_state_dict_from_url( - _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth", - map_location="cpu", - ) - text_model_state_dict = torch.hub.load_state_dict_from_url( - _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth", - map_location="cpu", - ) + url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth", + visual_model_head_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu") + url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth", + text_model_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu") model.visual_model.head.load_state_dict(visual_model_head_state_dict, strict=True) model.text_model.load_state_dict(text_model_state_dict, strict=True) return model diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py index 9c6641404..cc1316193 100644 --- a/dinov2/hub/utils.py +++ b/dinov2/hub/utils.py @@ -14,6 +14,14 @@ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" +def _safe_load_state_dict_from_url(url: str, **kwargs): + # See https://github.com/pytorch/pytorch/releases/tag/v2.1.0 (Misc / #98479) + if torch.__version__ >= (2, 1): + local_kwargs = {**kwargs, "weights_only": True} + else: + local_kwargs = kwargs + return torch.hub.load_state_dict_from_url(url, **local_kwargs) + def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: compact_arch_name = arch_name.replace("_", "")[:4] registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" diff --git a/dinov2/utils/utils.py b/dinov2/utils/utils.py index 68f8e2c3b..ac1b6ca31 100644 --- a/dinov2/utils/utils.py +++ b/dinov2/utils/utils.py @@ -19,7 +19,11 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key): if urlparse(pretrained_weights).scheme: # If it looks like an URL - state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + kwargs = {} + # See https://github.com/pytorch/pytorch/releases/tag/v2.1.0 (Misc / #98479) + if torch.__version__ >= (2, 1): + kwargs["weights_only"] = True + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu", **kwargs) else: state_dict = torch.load(pretrained_weights, map_location="cpu") if checkpoint_key is not None and checkpoint_key in state_dict: From bb4291b81aacefbe7320d015b2a7729c9d27d0d1 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 3 Jun 2026 15:44:29 +0200 Subject: [PATCH 2/2] Lint --- dinov2/hub/backbones.py | 2 -- dinov2/hub/depthers.py | 2 -- dinov2/hub/dinotxt.py | 5 ++--- dinov2/hub/utils.py | 1 + 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index f57d6bbc4..af5fd95d3 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -8,8 +8,6 @@ from typing import Optional, Union from urllib.parse import urlparse -import torch - from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url diff --git a/dinov2/hub/depthers.py b/dinov2/hub/depthers.py index fabbbb23e..3109ec844 100644 --- a/dinov2/hub/depthers.py +++ b/dinov2/hub/depthers.py @@ -7,8 +7,6 @@ from functools import partial from typing import Optional, Tuple, Union -import torch - from .backbones import _make_dinov2_model from .depth import BNHead, DepthEncoderDecoder, DPTHead from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url, CenterPadding diff --git a/dinov2/hub/dinotxt.py b/dinov2/hub/dinotxt.py index d4d892617..5678539f0 100644 --- a/dinov2/hub/dinotxt.py +++ b/dinov2/hub/dinotxt.py @@ -3,7 +3,6 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -import torch import math from .backbones import dinov2_vitl14_reg @@ -54,9 +53,9 @@ def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l(): model.visual_model.backbone = vision_backbone model.eval() - url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth", + url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth" visual_model_head_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu") - url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth", + url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth" text_model_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu") model.visual_model.head.load_state_dict(visual_model_head_state_dict, strict=True) model.text_model.load_state_dict(text_model_state_dict, strict=True) diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py index cc1316193..c6fa705f3 100644 --- a/dinov2/hub/utils.py +++ b/dinov2/hub/utils.py @@ -22,6 +22,7 @@ def _safe_load_state_dict_from_url(url: str, **kwargs): local_kwargs = kwargs return torch.hub.load_state_dict_from_url(url, **local_kwargs) + def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: compact_arch_name = arch_name.replace("_", "")[:4] registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""