-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
407 lines (325 loc) · 14 KB
/
Copy pathevaluate.py
File metadata and controls
407 lines (325 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
DefectFill 评估模块
包含 KID (Kernel Inception Distance) 和 IC-LPIPS (Inter-image Contextual LPIPS) 评估算法
评估指标说明:
- KID: 衡量生成图像与真实图像的分布距离(质量),越小越好
- IC-LPIPS: 衡量生成图像之间的感知差异(多样性),越大越好
"""
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import lpips
from PIL import Image
from torchvision import transforms, models
from datetime import datetime
import argparse
from tqdm import tqdm
from itertools import combinations
class KIDEvaluator:
"""
Kernel Inception Distance (KID) 评估器
KID 使用多项式核计算 Maximum Mean Discrepancy (MMD),
相比 FID 更适合小样本场景(MVTec 每类仅有数十张图像)
"""
def __init__(self, device="cuda"):
self.device = device
# 加载 InceptionV3 模型,使用 pool3 层特征(2048维)
self.inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
self.inception.fc = nn.Identity() # 移除分类头
self.inception = self.inception.to(device)
self.inception.eval()
# InceptionV3 的输入预处理
self.preprocess = transforms.Compose([
transforms.Resize((299, 299)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
@torch.no_grad()
def extract_features(self, images):
"""
从图像中提取 InceptionV3 特征
Args:
images: PIL Image 列表或 Tensor [N, 3, H, W]
Returns:
features: [N, 2048] 特征向量
"""
if isinstance(images, list):
# 处理 PIL Image 列表
tensors = []
for img in images:
if isinstance(img, str):
img = Image.open(img).convert('RGB')
tensor = self.preprocess(img)
tensors.append(tensor)
images = torch.stack(tensors)
images = images.to(self.device)
# 批量处理以避免显存溢出
batch_size = 32
features_list = []
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
feat = self.inception(batch)
features_list.append(feat.cpu())
return torch.cat(features_list, dim=0)
def polynomial_kernel(self, x, y, degree=3, gamma=None, coef0=1):
"""
计算多项式核
k(x, y) = (gamma * <x, y> + coef0)^degree
"""
if gamma is None:
gamma = 1.0 / x.shape[1]
return (gamma * torch.mm(x, y.t()) + coef0) ** degree
def compute_mmd(self, x, y):
"""
计算 Maximum Mean Discrepancy (MMD)
MMD^2 = E[k(x,x')] - 2*E[k(x,y)] + E[k(y,y')]
"""
k_xx = self.polynomial_kernel(x, x)
k_yy = self.polynomial_kernel(y, y)
k_xy = self.polynomial_kernel(x, y)
# 无偏估计
n = x.shape[0]
m = y.shape[0]
# 移除对角线元素(自身与自身的比较)
mmd = (k_xx.sum() - k_xx.trace()) / (n * (n - 1))
mmd += (k_yy.sum() - k_yy.trace()) / (m * (m - 1))
mmd -= 2 * k_xy.mean()
return mmd
def compute_kid(self, real_images, gen_images, num_subsets=100, subset_size=None):
"""
计算 KID 分数
Args:
real_images: 真实缺陷图像(PIL Image 列表或路径列表)
gen_images: 生成缺陷图像(PIL Image 列表或路径列表)
num_subsets: 子集采样次数(用于计算均值和方差)
subset_size: 每个子集大小,默认 min(len(real), len(gen))
Returns:
kid_mean: KID 均值(越小越好)
kid_std: KID 标准差
"""
print("提取真实图像特征...")
real_features = self.extract_features(real_images)
print(f" 真实图像特征形状: {real_features.shape}")
print("提取生成图像特征...")
gen_features = self.extract_features(gen_images)
print(f" 生成图像特征形状: {gen_features.shape}")
if subset_size is None:
subset_size = min(len(real_features), len(gen_features))
# 多次子集采样计算 KID
kid_scores = []
for _ in range(num_subsets):
# 随机采样子集
idx_real = np.random.choice(len(real_features), subset_size, replace=False)
idx_gen = np.random.choice(len(gen_features), subset_size, replace=False)
# 计算 MMD
mmd = self.compute_mmd(
real_features[idx_real],
gen_features[idx_gen]
)
kid_scores.append(mmd.item())
return np.mean(kid_scores), np.std(kid_scores)
class ICLPIPSEvaluator:
"""
Inter-image Contextual LPIPS (IC-LPIPS) 评估器
计算生成图像之间的感知差异,用于评估生成的多样性。
IC-LPIPS 越大表示生成的图像越多样化。
参考论文: "Few-shot Image Generation via Cross-domain Correspondence" (CVPR 2021)
"""
def __init__(self, net='vgg', device="cuda"):
self.device = device
# 使用标准 LPIPS(非空间版本)
self.lpips_net = lpips.LPIPS(net=net, spatial=False).to(device)
self.lpips_net.eval()
# 图像预处理(转换到 [-1, 1] 范围)
self.preprocess = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
def _load_image(self, img):
"""加载并预处理图像"""
if isinstance(img, str):
img = Image.open(img).convert('RGB')
if isinstance(img, Image.Image):
img = self.preprocess(img)
return img.to(self.device)
@torch.no_grad()
def compute_pairwise_lpips(self, img1, img2):
"""
计算两张图像之间的 LPIPS 分数
Args:
img1: 图像1 (路径、PIL Image 或 Tensor)
img2: 图像2 (路径、PIL Image 或 Tensor)
Returns:
lpips_score: LPIPS 分数(越大表示差异越大)
"""
img1_tensor = self._load_image(img1).unsqueeze(0)
img2_tensor = self._load_image(img2).unsqueeze(0)
lpips_score = self.lpips_net(img1_tensor, img2_tensor)
return lpips_score.item()
@torch.no_grad()
def compute_ic_lpips(self, generated_images, max_pairs=1000):
"""
计算生成图像集合的 IC-LPIPS 分数(多样性评估)
计算所有生成图像对之间的 LPIPS 距离的平均值。
值越大表示生成的图像越多样化。
Args:
generated_images: 生成图像列表(路径、PIL Image 或 Tensor)
max_pairs: 最大计算的图像对数量(用于大数据集)
Returns:
ic_lpips_mean: IC-LPIPS 均值(越大越好,表示多样性越高)
ic_lpips_std: IC-LPIPS 标准差
"""
n_images = len(generated_images)
if n_images < 2:
print("警告: 图像数量不足,无法计算 IC-LPIPS")
return float('nan'), float('nan')
# 预加载所有图像
print(f"加载 {n_images} 张生成图像...")
image_tensors = []
for img in tqdm(generated_images, desc="加载图像"):
img_tensor = self._load_image(img)
image_tensors.append(img_tensor)
# 堆叠成批次
image_batch = torch.stack(image_tensors, dim=0)
# 生成所有图像对的索引
all_pairs = list(combinations(range(n_images), 2))
n_pairs = len(all_pairs)
print(f"总计 {n_pairs} 个图像对")
# 如果图像对太多,随机采样
if n_pairs > max_pairs:
print(f"随机采样 {max_pairs} 个图像对进行计算")
selected_pairs = np.random.choice(n_pairs, max_pairs, replace=False)
pairs_to_compute = [all_pairs[i] for i in selected_pairs]
else:
pairs_to_compute = all_pairs
# 计算所有选定的图像对的 LPIPS
lpips_scores = []
# 批量计算以提高效率
batch_size = 32
for i in tqdm(range(0, len(pairs_to_compute), batch_size), desc="计算 IC-LPIPS"):
batch_pairs = pairs_to_compute[i:i+batch_size]
img1_batch = torch.stack([image_batch[p[0]] for p in batch_pairs], dim=0)
img2_batch = torch.stack([image_batch[p[1]] for p in batch_pairs], dim=0)
scores = self.lpips_net(img1_batch, img2_batch)
lpips_scores.extend(scores.squeeze().cpu().tolist() if len(batch_pairs) > 1 else [scores.item()])
ic_lpips_mean = np.mean(lpips_scores)
ic_lpips_std = np.std(lpips_scores)
return ic_lpips_mean, ic_lpips_std
def collect_generated_images(directory):
"""
收集目录中所有生成的图像(用于评估)
只收集 *_generated.png 文件
"""
images = []
for root, dirs, files in os.walk(directory):
for file in files:
# 只收集生成的图像
if file.endswith('_generated.png'):
images.append(os.path.join(root, file))
return images
def collect_real_defect_images(directory):
"""
收集真实缺陷图像目录中的所有图像(用于 KID 评估)
"""
images = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith(('.png', '.jpg', '.jpeg')):
# 排除掩码文件
if '_mask' not in file:
images.append(os.path.join(root, file))
return images
def evaluate(args):
"""
执行评估
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 初始化评估器
print("\n初始化评估器...")
kid_evaluator = KIDEvaluator(device=device)
ic_lpips_evaluator = ICLPIPSEvaluator(device=device)
# 收集图像
print("\n收集图像...")
# 收集所有生成的图像
gen_images_all = collect_generated_images(args.generated_dir)
print(f" 生成图像数量: {len(gen_images_all)}")
# 收集所有真实缺陷图像(用于 KID)
real_images_all = collect_real_defect_images(args.real_dir)
print(f" 真实图像数量: {len(real_images_all)}")
# 计算 KID(质量评估)
print("\n计算 KID(质量评估)...")
if len(gen_images_all) > 0 and len(real_images_all) > 0:
kid_mean, kid_std = kid_evaluator.compute_kid(
real_images_all, gen_images_all,
num_subsets=min(100, len(gen_images_all)),
subset_size=min(len(gen_images_all), len(real_images_all))
)
print(f" KID: {kid_mean:.6f} ± {kid_std:.6f} (越小越好)")
else:
kid_mean, kid_std = float('nan'), float('nan')
print(" 警告: 图像数量不足,无法计算 KID")
# 计算 IC-LPIPS(多样性评估)
print("\n计算 IC-LPIPS(多样性评估)...")
if len(gen_images_all) >= 2:
ic_lpips_mean, ic_lpips_std = ic_lpips_evaluator.compute_ic_lpips(
gen_images_all,
max_pairs=min(1000, len(gen_images_all) * (len(gen_images_all) - 1) // 2)
)
print(f" IC-LPIPS: {ic_lpips_mean:.6f} ± {ic_lpips_std:.6f} (越大越好,表示多样性越高)")
else:
ic_lpips_mean, ic_lpips_std = float('nan'), float('nan')
print(" 警告: 生成图像数量不足,无法计算 IC-LPIPS")
# 保存结果
timestamp = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
result_row = [
timestamp,
args.class_name,
args.config_name,
args.category_type,
f"{kid_mean:.6f}",
f"{kid_std:.6f}",
f"{ic_lpips_mean:.6f}",
f"{ic_lpips_std:.6f}"
]
# 追加到 CSV 文件
file_exists = os.path.exists(args.output_csv)
with open(args.output_csv, 'a', newline='') as f:
writer = csv.writer(f)
if not file_exists:
# 写入表头
writer.writerow([
'timestamp', 'class', 'config', 'category_type',
'KID_mean', 'KID_std', 'IC_LPIPS_mean', 'IC_LPIPS_std'
])
writer.writerow(result_row)
print(f"\n结果已追加到: {args.output_csv}")
# 返回结果
return {
'kid_mean': kid_mean,
'kid_std': kid_std,
'ic_lpips_mean': ic_lpips_mean,
'ic_lpips_std': ic_lpips_std
}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="DefectFill 评估模块")
parser.add_argument("--generated_dir", type=str, required=True,
help="生成图像目录")
parser.add_argument("--real_dir", type=str, required=True,
help="真实缺陷图像目录")
parser.add_argument("--output_csv", type=str, required=True,
help="输出 CSV 文件路径")
parser.add_argument("--class_name", type=str, required=True,
help="类别名称")
parser.add_argument("--config_name", type=str, required=True,
help="配置名称")
parser.add_argument("--category_type", type=str, default="unknown",
choices=["object", "texture", "unknown"],
help="类别类型 (object/texture)")
# 移除了 --mask_dir 参数,因为 IC-LPIPS 不需要掩码
args = parser.parse_args()
evaluate(args)