diff --git a/detic/predictor.py b/detic/predictor.py index 047ed80..26a0578 100644 --- a/detic/predictor.py +++ b/detic/predictor.py @@ -23,13 +23,15 @@ def get_clip_embeddings(vocabulary, prompt='a '): cache_file_path = f"/tmp/detic-clip-embeddings-{hash_value}.pt" if Path(cache_file_path).exists(): print(f"loading embeddings for {vocabulary} from {cache_file_path}") - return torch.load(cache_file_path) + return torch.load(cache_file_path, map_location='cpu') else: from detic.modeling.text.text_encoder import build_text_encoder text_encoder = build_text_encoder(pretrain=True) - text_encoder.eval() + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + text_encoder.eval().to(device) texts = [prompt + x for x in vocabulary] - emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + with torch.no_grad(): + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() print(f"saved embeddings for {vocabulary} to {cache_file_path}") torch.save(emb, cache_file_path) return emb