-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathadacon_model.py
More file actions
118 lines (92 loc) · 4.26 KB
/
adacon_model.py
File metadata and controls
118 lines (92 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from utils.google_utils import *
from utils.layers import *
from utils.parse_config import *
import utils.extras as extras
from models import Darknet, Backbone, BranchController
from enum import Enum
import numpy as np
class AdaConMode(Enum):
    """Inference-time execution modes for ``AdaConYolo``.

    ``AdaConYolo`` compares ``exec_mode`` against these integer values.
    """

    multi_branch = 1   # run every branch whose controller score exceeds the threshold
    single_branch = 2  # run only the arg-max branch chosen by the controller
    oracle = 3         # run every branch unconditionally


class AdaConYolo(nn.Module):
    """YOLO detector built from a shared backbone and per-cluster branches.

    A ``Backbone`` feeds one ``Darknet`` branch per class cluster
    (``self.clusters``). At inference time a ``BranchController`` scores the
    clusters, and ``exec_mode`` (see ``AdaConMode``) decides which branches run.
    Each branch's cluster-local class scores are then scattered back into a
    tensor with ``num_classes + 5`` columns.

    Args:
        model_args: Model configuration, parsed with ``parse_model_args``.
            Keys read: ``num_classes``, ``clusters``, ``backbone_cfg``,
            ``backbone_weights``, ``branch_controller_cfg``, ``branches_cfg``,
            and optionally ``branch_controller_weights`` / ``branches_weights``.
        img_size: Input image size forwarded to every ``Darknet`` branch.
        exec_mode: Inference mode, as an ``AdaConMode`` member or its int value.
            The default ``1`` is ``AdaConMode.multi_branch``.
        multi_branch_thres: Controller score above which a branch runs in
            ``multi_branch`` mode.
    """

    def __init__(self, model_args, img_size=(416, 416), exec_mode=1, multi_branch_thres=0.1):
        super().__init__()
        model_args = parse_model_args(model_args)
        self.num_classes = int(model_args['num_classes'])
        self.clusters = parse_clusters_config(model_args['clusters'])
        self.class_to_cluster_list = get_class_to_cluster_map(self.clusters)
        self.exec_mode = exec_mode
        self.multi_branch_thres = multi_branch_thres

        backbone_cfg = model_args['backbone_cfg']
        backbone_weights = model_args['backbone_weights']
        branch_controller_cfg = model_args['branch_controller_cfg']
        branch_controller_weights = (
            model_args['branch_controller_weights']
            if 'branch_controller_weights' in model_args else None
        )
        branches_cfg = model_args['branches_cfg']
        branches_weights = (
            [check_file(f) for f in model_args['branches_weights']]
            if 'branches_weights' in model_args else None
        )

        self.backbone = Backbone(backbone_cfg)
        # NOTE(review): 100 is passed straight through to load_darknet_weights
        # (presumably a layer cutoff) -- confirm against Backbone.
        self.backbone.load_darknet_weights(backbone_weights, 100)

        self.branches = nn.ModuleList()
        for i, cfg in enumerate(branches_cfg):
            branch = Darknet(cfg, img_size)
            if branches_weights:
                # Branch checkpoints keep the state dict under the 'model' key.
                # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
                # hosts; load_state_dict copies into the module's own tensors.
                checkpoint = torch.load(branches_weights[i], map_location='cpu')
                branch.load_state_dict(checkpoint['model'])
            self.branches.append(branch)

        self.branch_controller = BranchController(branch_controller_cfg, len(self.clusters))
        if branch_controller_weights:
            self.branch_controller.load_state_dict(
                torch.load(branch_controller_weights, map_location='cpu'))

    def forward(self, x):
        """Dispatch to the training or inference pass depending on ``self.training``."""
        if self.training:
            return self._forward_training(x)
        return self._forward_testing(x)

    def _forward_training(self, x):
        """Run every branch on the backbone output.

        Returns:
            ``(preds, None)``, where ``preds`` holds each branch's raw output
            in branch order.
        """
        back_out = self.backbone(x)
        try:
            # Branches also read the backbone's cached intermediate outputs.
            preds = [branch(back_out, out=self.backbone.layer_outputs)
                     for branch in self.branches]
        finally:
            # Drop the cached outputs so they never leak into the next call.
            self.backbone.layer_outputs = []
        return preds, None

    def _active_branch_indices(self, back_out):
        """Return the set of branch indices to run under ``self.exec_mode``.

        Raises:
            ValueError: if ``self.exec_mode`` is not an ``AdaConMode`` member
                or value.
        """
        mode = self.exec_mode
        if isinstance(mode, AdaConMode):
            mode = mode.value
        if mode == AdaConMode.oracle.value:
            return set(range(len(self.branches)))
        if mode not in (AdaConMode.single_branch.value, AdaConMode.multi_branch.value):
            raise ValueError(
                f"Unsupported exec_mode {self.exec_mode!r}; "
                f"expected one of {[m.value for m in AdaConMode]}")

        class_out = self.branch_controller(back_out, [])
        if mode == AdaConMode.single_branch.value:
            # Per-sample arg-max over the cluster (last) axis, unioned over
            # the batch. A flat argmax is only correct for batch size 1.
            indices = class_out.argmax(dim=-1).reshape(-1)
        else:
            # Column (cluster) index of every score above the threshold,
            # across all rows -- assumes class_out is (batch, clusters).
            indices = torch.where(class_out > self.multi_branch_thres)[1]
        return {int(i) for i in indices.tolist()}

    def _expand_to_all_classes(self, branch_out, cluster_idx):
        """Scatter one branch's cluster-local scores into ``num_classes + 5`` columns.

        Columns ``0:5`` of ``branch_out`` (presumably box + objectness) are
        copied unchanged. Column ``5 + j`` is moved to column ``5 + k``, where
        ``k`` is the j-th class id in ``self.clusters[cluster_idx]``. All other
        class columns are zero.
        """
        batch, num_boxes = branch_out.shape[0], branch_out.shape[1]
        # new_zeros inherits branch_out's device and dtype; x.get_device()
        # returns -1 for CPU tensors, which torch rejects as a device.
        full = branch_out.new_zeros((batch, num_boxes, self.num_classes + 5))
        full[:, :, 0:5] = branch_out[:, :, 0:5]
        class_cols = [5 + k for k in self.clusters[cluster_idx]]
        full[:, :, class_cols] = branch_out[:, :, 5:]
        return full

    def _forward_testing(self, x):
        """Run the branches selected by ``exec_mode`` and merge their detections.

        Returns:
            The per-branch detections, each expanded to ``num_classes + 5``
            columns and concatenated along dim 1. If no branch is selected,
            returns an empty ``(batch, 0, num_classes + 5)`` tensor.

        Raises:
            ValueError: if ``self.exec_mode`` is not a valid mode.
        """
        back_out = self.backbone(x)
        try:
            active = self._active_branch_indices(back_out)
            preds = []
            for cluster_idx, branch in enumerate(self.branches):
                if cluster_idx not in active:
                    continue
                branch_out, _ = branch(back_out, out=self.backbone.layer_outputs)
                preds.append(self._expand_to_all_classes(branch_out, cluster_idx))
        finally:
            self.backbone.layer_outputs = []
        if not preds:
            # torch.cat rejects an empty list, e.g. when no controller score
            # passes multi_branch_thres.
            return x.new_zeros((x.shape[0], 0, self.num_classes + 5))
        return torch.cat(preds, 1)
# def backward(self, losses):
# def _max(losses):
# max_loss = 0
# for i in range(1,len(losses)):
# if losses[i] > losses[max_loss]:
# max_loss = i
# return max_loss
# # i = np.argmax(losses)
# with torch.no_grad():
# i = _max(losses)
# # print(losses, i)
# # for i, loss in enumerate(losses):
# # if i < len(losses) - 1:
# losses[i].backward()
# # else:
# # loss.backward(retain_graph=False)