Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions dinov2/hub/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from typing import Optional, Union
from urllib.parse import urlparse

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url


class Weights(Enum):
Expand Down Expand Up @@ -73,7 +71,7 @@ def _make_dinov2_model(
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
else:
url = convert_path_or_url_to_url(weights)
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash)
state_dict = _safe_load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash)
model.load_state_dict(state_dict, strict=True)

return model
Expand Down
6 changes: 4 additions & 2 deletions dinov2/hub/cell_dino/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import torch

from ..utils import _safe_load_state_dict_from_url


class Weights(Enum):
CELL_DINO = "CELL-DINO"
Expand Down Expand Up @@ -57,8 +59,8 @@ def _make_cell_dino_model(
if pretrained_path is not None:
state_dict = torch.load(pretrained_path, map_location="cpu")
else:
pretrained_url is not None
state_dict = torch.hub.load_state_dict_from_url(pretrained_url, map_location="cpu")
assert pretrained_url is not None
state_dict = _safe_load_state_dict_from_url(pretrained_url, map_location="cpu")
model.load_state_dict(state_dict, strict=True)

return model
Expand Down
4 changes: 2 additions & 2 deletions dinov2/hub/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn

from .backbones import _make_dinov2_model
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url


class Weights(Enum):
Expand Down Expand Up @@ -43,7 +43,7 @@ def _make_dinov2_linear_classification_head(
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
layers_str = str(layers) if layers == 4 else ""
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
state_dict = _safe_load_state_dict_from_url(url, map_location="cpu")
linear_head.load_state_dict(state_dict, strict=True)

return linear_head
Expand Down
8 changes: 3 additions & 5 deletions dinov2/hub/depthers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from functools import partial
from typing import Optional, Tuple, Union

import torch

from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder, DPTHead
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, _safe_load_state_dict_from_url, CenterPadding


class Weights(Enum):
Expand Down Expand Up @@ -132,7 +130,7 @@ def _make_dinov2_linear_depther(
layers_str = str(layers) if layers == 4 else ""
weights_str = weights.value.lower()
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
checkpoint = _safe_load_state_dict_from_url(url, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict, strict=False)
Expand Down Expand Up @@ -220,7 +218,7 @@ def _make_dinov2_dpt_depther(
if pretrained:
weights_str = weights.value.lower()
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
checkpoint = _safe_load_state_dict_from_url(url, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict, strict=False)
Expand Down
15 changes: 5 additions & 10 deletions dinov2/hub/dinotxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import math

from .backbones import dinov2_vitl14_reg
from .utils import _DINOV2_BASE_URL
from .utils import _DINOV2_BASE_URL, _safe_load_state_dict_from_url


def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l():
Expand Down Expand Up @@ -54,14 +53,10 @@ def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l():
model.visual_model.backbone = vision_backbone
model.eval()

visual_model_head_state_dict = torch.hub.load_state_dict_from_url(
_DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth",
map_location="cpu",
)
text_model_state_dict = torch.hub.load_state_dict_from_url(
_DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth",
map_location="cpu",
)
url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth"
visual_model_head_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu")
url = _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth"
text_model_state_dict = _safe_load_state_dict_from_url(url, map_location="cpu")
model.visual_model.head.load_state_dict(visual_model_head_state_dict, strict=True)
model.text_model.load_state_dict(text_model_state_dict, strict=True)
return model
Expand Down
9 changes: 9 additions & 0 deletions dinov2/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _safe_load_state_dict_from_url(url: str, **kwargs):
    """Download a state dict from ``url`` via ``torch.hub.load_state_dict_from_url``.

    On torch 2.1 and newer the call is made with ``weights_only=True``; any
    ``weights_only`` value supplied by the caller is overridden in that case.
    On older versions the keyword arguments are forwarded unchanged.

    Args:
        url: Location of the checkpoint to fetch.
        **kwargs: Extra keyword arguments passed through to
            ``torch.hub.load_state_dict_from_url`` (e.g. ``map_location``,
            ``check_hash``).

    Returns:
        Whatever ``torch.hub.load_state_dict_from_url`` returns for ``url``.
    """
    # See https://github.com/pytorch/pytorch/releases/tag/v2.1.0 (Misc / #98479)
    forces_weights_only = torch.__version__ >= (2, 1)
    # Build a new dict rather than mutating the caller's kwargs.
    forwarded = dict(kwargs, weights_only=True) if forces_weights_only else kwargs
    return torch.hub.load_state_dict_from_url(url, **forwarded)


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
compact_arch_name = arch_name.replace("_", "")[:4]
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
Expand Down
6 changes: 5 additions & 1 deletion dinov2/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
if urlparse(pretrained_weights).scheme: # If it looks like an URL
state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
kwargs = {}
# See https://github.com/pytorch/pytorch/releases/tag/v2.1.0 (Misc / #98479)
if torch.__version__ >= (2, 1):
kwargs["weights_only"] = True
state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu", **kwargs)
else:
state_dict = torch.load(pretrained_weights, map_location="cpu")
if checkpoint_key is not None and checkpoint_key in state_dict:
Expand Down
Loading