基于 Stable Diffusion Inpainting 的工业缺陷生成与修复实验项目。
论文:DefectFill: Realistic Defect Generation with Inpainting Diffusion Model for Visual Inspection
arXiv: https://arxiv.org/abs/2503.13985
| 功能 | 文件 |
|---|---|
| 模型封装(UNet + 文本编码器 + VAE,LoRA 微调) | model.py |
| 训练(注意力损失、梯度累积、TensorBoard) | train.py |
| 推理 / 生成(CFG、迭代式背景保持) | inference.py |
| 数据加载与预处理 | data_loader.py |
| 检查点与工具函数 | utils.py |
| 评估(KID、IC-LPIPS) | evaluate.py |
| 实验结果可视化 | visualize_results.py |
推荐 Python 3.11(已在 3.11.9 下测试)。
pip install -r requirements.txt主要依赖(见 requirements.txt):
- torch / torchvision
- diffusers
- transformers
- accelerate, bitsandbytes, peft
- lpips(评估)
- albumentations, opencv-python
- tqdm
conda create -n defectfill python=3.11 -y
conda activate defectfill
pip install -r requirements.txt与 data_loader.py 约定一致:
MVTec2/
bottle/
train/
defective/
broken_large/
000.png ...
defective_masks/
broken_large/
000_mask.png ...
test/
good/
000.png ...
已按上述目录整理好的 MVTec 数据可从百度网盘下载:
链接: https://pan.baidu.com/s/1xwr5AkrmLi6ahPvU5eQL8Q 提取码: g84e
(由百度网盘 SVIP 分享)
- 每批样本字段示例:
image,mask,background,adjusted_mask,is_defect,object_class,defect_type - 仅使用缺陷样本:
is_defect == 1 - 两条分支(同一缺陷样本):
- 真实缺陷掩码:学习缺陷语义(Defect Loss)
- prompt 模板:
"A photo of {defect_type}"
- prompt 模板:
- 随机矩形掩码:学习物体完整性(Object Loss)
- prompt 模板:
"A {object_class} with {defect_type}"
- prompt 模板:
- 真实缺陷掩码:学习缺陷语义(Defect Loss)
总损失(见 train.py 约 235 行):
L_total = λ_defect * L_defect + λ_object * L_object + λ_attn * L_attention
可选参数:--lambda_defect, --lambda_obj, --lambda_attn, --alpha(Object 分支背景权重)。
python train.py \
--data_dir "/path/to/MVTec2" \
--object_class bottle \
--defect_type broken_large \
--output_dir "./models/bottle" \
--config_name base \
--lambda_defect 0.5 --lambda_obj 0.2 --alpha 0.3 \
--batch_size 2 \
--max_train_steps 2000常用可选参数(见 train.py 中 ArgumentParser):
--lora_rank,--lora_alpha--text_encoder_lr,--unet_lr--save_steps,--resume_from--gradient_accumulation_steps,--lr_warmup_steps--seed
- 梯度累积、学习率预热
- 检查点:按
save_steps保存,结束时可保存checkpoint_final.pt - 日志:
output_dir/train_log.txt,TensorBoard:output_dir/tensorboard
tensorboard --logdir ./models/bottle/tensorboard --port 6006脚本:inference.py。使用 model.generate():9 通道输入、Classifier-Free Guidance、迭代式背景保持、可学习 <defect> token。
单张示例:
python inference.py \
--checkpoint checkpoint_final.pt \
--object_class bottle \
--defect_type broken_large \
--image_path 000.png \
--mask_path 012_mask.png \
--data_dir "/path/to/MVTec2" \
--output_dir ./generated输出:修复/生成图像,可选保存注意力图(在模型中启用 attention_maps)。
脚本:evaluate.py。指标说明:
- KID:生成图像与真实图像的分布距离(质量),越小越好
- IC-LPIPS:生成图像间感知差异(多样性),越大越好
单次评估示例:
python evaluate.py \
--generated_dir ./experiments/bottle/base/generated \
--real_dir "/path/to/MVTec2/bottle/train/defective/broken_large" \
--output_csv evaluation_results.csv \
--class_name bottle \
--config_name base \
--category_type object--category_type 可选:object / texture / unknown。
脚本:visualize_results.py。根据评估 CSV 生成:
- 热力图:类别 × 配置的 KID / IC-LPIPS 矩阵
- 散点图:质量–多样性权衡(KID vs IC-LPIPS)
- 分组柱状图:各配置在不同类别上的对比
- 汇总统计:
summary_statistics.csv
CSV 需包含列:class, config, category_type, KID_mean, KID_std, IC_LPIPS_mean, IC_LPIPS_std。
示例:
python visualize_results.py \
--csv_path evaluation_results.csv \
--output_dir ./figures| 脚本 | 说明 |
|---|---|
| run_train_test.sh | 小规模测试训练(如 bottle、carpet × base/tex/obj) |
| run_train_all.sh | 全量训练(10 类 × 3 配置,需在脚本中配置 OBJECT_CLASSES 与 ALL_CLASSES) |
| run_inference_test.sh | 对测试实验目录做批量推理 |
| run_inference_all.sh | 对完整实验目录做批量推理 |
| run_evaluation_test.sh | 对测试实验评估并写 CSV |
| run_evaluation.sh | 对完整实验评估并写 CSV |
| run_visualization.sh | 根据 CSV 自动生成图表(可指定 test/full 或 CSV 路径) |
脚本内路径(如 DATA_DIR、EXPERIMENTS_DIR、SCRIPT_DIR)需按本机环境修改。
model.py 中 DefectFillModel 加载 Stable Diffusion Inpainting(UNet + 文本编码器 + VAE)。
微调方式:LoRA(--lora_rank),仅训练 LoRA 参数;具体可训练模块见代码。
# 1. 环境
conda create -n defectfill python=3.11 -y
conda activate defectfill
pip install -r requirements.txt
# 2. 单类训练
python train.py --data_dir ./MVTec2 --object_class bottle \
--defect_type broken_large --output_dir ./models/bottle --config_name base
# 3. 监控
tensorboard --logdir ./models/bottle/tensorboard
# 4. 推理
python inference.py --checkpoint ./models/bottle/checkpoints/checkpoint_final.pt \
--object_class bottle --defect_type broken_large \
--data_dir ./MVTec2 --output_dir ./generated
# 5. 评估
python evaluate.py --generated_dir ./generated \
--real_dir ./MVTec2/bottle/train/defective/broken_large \
--output_csv results.csv --class_name bottle --config_name base --category_type object
# 6. 可视化
python visualize_results.py --csv_path results.csv --output_dir ./figures