Skip to content

lhj23333/DefectFillcore

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DefectFill 非官方实现

基于 Stable Diffusion Inpainting 的工业缺陷生成与修复实验项目。
论文:DefectFill: Realistic Defect Generation with Inpainting Diffusion Model for Visual Inspection
arXiv: https://arxiv.org/abs/2503.13985


1. 主要功能与文件

功能 文件
模型封装(UNet + 文本编码器 + VAE,LoRA 微调) model.py
训练(注意力损失、梯度累积、TensorBoard) train.py
推理 / 生成(CFG、迭代式背景保持) inference.py
数据加载与预处理 data_loader.py
检查点与工具函数 utils.py
评估(KID、IC-LPIPS) evaluate.py
实验结果可视化 visualize_results.py

2. 环境配置

2.1 Python 版本

推荐 Python 3.11(已在 3.11.9 下测试)。

2.2 安装依赖

pip install -r requirements.txt

主要依赖(见 requirements.txt):

  • torch / torchvision
  • diffusers
  • transformers
  • accelerate, bitsandbytes, peft
  • lpips(评估)
  • albumentations, opencv-python
  • tqdm

2.3 推荐 Conda 环境

conda create -n defectfill python=3.11 -y
conda activate defectfill
pip install -r requirements.txt

3. 数据集

3.1 目录结构示例

data_loader.py 约定一致:

MVTec2/
  bottle/
    train/
      defective/
        broken_large/
          000.png ...
      defective_masks/
        broken_large/
          000_mask.png ...
    test/
      good/
        000.png ...

3.2 数据集获取

已按上述目录整理好的 MVTec 数据可从百度网盘下载:
链接: https://pan.baidu.com/s/1xwr5AkrmLi6ahPvU5eQL8Q 提取码: g84e
(由百度网盘 SVIP 分享)


4. 训练数据与损失

  • 每批样本字段示例:image, mask, background, adjusted_mask, is_defect, object_class, defect_type
  • 仅使用缺陷样本:is_defect == 1
  • 两条分支(同一缺陷样本):
    1. 真实缺陷掩码:学习缺陷语义(Defect Loss)
      • prompt 模板:"A photo of {defect_type}"
    2. 随机矩形掩码:学习物体完整性(Object Loss)
      • prompt 模板:"A {object_class} with {defect_type}"

总损失(见 train.py 约 235 行):

L_total = λ_defect * L_defect + λ_object * L_object + λ_attn * L_attention

可选参数:--lambda_defect, --lambda_obj, --lambda_attn, --alpha(Object 分支背景权重)。


5. 训练

5.1 单次训练命令

python train.py \
  --data_dir "/path/to/MVTec2" \
  --object_class bottle \
  --defect_type broken_large \
  --output_dir "./models/bottle" \
  --config_name base \
  --lambda_defect 0.5 --lambda_obj 0.2 --alpha 0.3 \
  --batch_size 2 \
  --max_train_steps 2000

常用可选参数(见 train.py 中 ArgumentParser):

  • --lora_rank, --lora_alpha
  • --text_encoder_lr, --unet_lr
  • --save_steps, --resume_from
  • --gradient_accumulation_steps, --lr_warmup_steps
  • --seed

5.2 训练特性

  • 梯度累积、学习率预热
  • 检查点:按 save_steps 保存,结束时可保存 checkpoint_final.pt
  • 日志:output_dir/train_log.txt,TensorBoard:output_dir/tensorboard
tensorboard --logdir ./models/bottle/tensorboard --port 6006

6. 推理 / 生成

脚本:inference.py。使用 model.generate():9 通道输入、Classifier-Free Guidance、迭代式背景保持、可学习 <defect> token。

单张示例:

python inference.py \
  --checkpoint checkpoint_final.pt \
  --object_class bottle \
  --defect_type broken_large \
  --image_path 000.png \
  --mask_path 012_mask.png \
  --data_dir "/path/to/MVTec2" \
  --output_dir ./generated

输出:修复/生成图像,可选保存注意力图(在模型中启用 attention_maps)。


7. 评估

脚本:evaluate.py。指标说明:

  • KID:生成图像与真实图像的分布距离(质量),越小越好
  • IC-LPIPS:生成图像间感知差异(多样性),越大越好

单次评估示例:

python evaluate.py \
  --generated_dir ./experiments/bottle/base/generated \
  --real_dir "/path/to/MVTec2/bottle/train/defective/broken_large" \
  --output_csv evaluation_results.csv \
  --class_name bottle \
  --config_name base \
  --category_type object

--category_type 可选:object / texture / unknown


8. 可视化

脚本:visualize_results.py。根据评估 CSV 生成:

  • 热力图:类别 × 配置的 KID / IC-LPIPS 矩阵
  • 散点图:质量–多样性权衡(KID vs IC-LPIPS)
  • 分组柱状图:各配置在不同类别上的对比
  • 汇总统计:summary_statistics.csv

CSV 需包含列:class, config, category_type, KID_mean, KID_std, IC_LPIPS_mean, IC_LPIPS_std

示例:

python visualize_results.py \
  --csv_path evaluation_results.csv \
  --output_dir ./figures

9. 脚本(scripts/)

脚本 说明
run_train_test.sh 小规模测试训练(如 bottle、carpet × base/tex/obj)
run_train_all.sh 全量训练(10 类 × 3 配置,需在脚本中配置 OBJECT_CLASSESALL_CLASSES
run_inference_test.sh 对测试实验目录做批量推理
run_inference_all.sh 对完整实验目录做批量推理
run_evaluation_test.sh 对测试实验评估并写 CSV
run_evaluation.sh 对完整实验评估并写 CSV
run_visualization.sh 根据 CSV 自动生成图表(可指定 test/full 或 CSV 路径)

脚本内路径(如 DATA_DIREXPERIMENTS_DIRSCRIPT_DIR)需按本机环境修改。


10. 模型与可训练部分

model.pyDefectFillModel 加载 Stable Diffusion Inpainting(UNet + 文本编码器 + VAE)。
微调方式:LoRA(--lora_rank),仅训练 LoRA 参数;具体可训练模块见代码。


11. 快速开始

# 1. 环境
conda create -n defectfill python=3.11 -y
conda activate defectfill
pip install -r requirements.txt

# 2. 单类训练
python train.py --data_dir ./MVTec2 --object_class bottle \
  --defect_type broken_large --output_dir ./models/bottle --config_name base

# 3. 监控
tensorboard --logdir ./models/bottle/tensorboard

# 4. 推理
python inference.py --checkpoint ./models/bottle/checkpoints/checkpoint_final.pt \
  --object_class bottle --defect_type broken_large \
  --data_dir ./MVTec2 --output_dir ./generated

# 5. 评估
python evaluate.py --generated_dir ./generated \
  --real_dir ./MVTec2/bottle/train/defective/broken_large \
  --output_csv results.csv --class_name bottle --config_name base --category_type object

# 6. 可视化
python visualize_results.py --csv_path results.csv --output_dir ./figures

About

Graduation Design -- Real Defect Generation Based on Inpainting Diffusion

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors