Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions FMF/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import os
import torch
import logging
import shutil
from typing import Dict, Optional, Union

logger = logging.getLogger(__name__)

def save_checkpoint(
    state: Dict,
    is_best: bool,
    output_dir: str,
    filename: str = 'checkpoint.pth'
) -> None:
    """
    Save a checkpoint and, if it is the best so far, mirror it as the best model.

    Args:
        state: Dictionary holding the model state, optimizer state, etc.
        is_best: Whether this checkpoint is the best model so far.
        output_dir: Directory the checkpoint is written to (created if missing).
        filename: File name of the checkpoint inside ``output_dir``.
    """
    # exist_ok=True removes the check-then-create race: another process may
    # create the directory between an existence check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # Write the checkpoint itself.
    filepath = os.path.join(output_dir, filename)
    torch.save(state, filepath)
    logger.info(f'Saved checkpoint to {filepath}')

    # Keep a copy of the best model under a fixed name.
    if is_best:
        best_filepath = os.path.join(output_dir, 'model_best.pth')
        # shutil.copyfile raises SameFileError when source and destination are
        # the same file, which happens if the caller already used the
        # best-model name as ``filename``; the file is then already in place.
        if os.path.abspath(best_filepath) != os.path.abspath(filepath):
            shutil.copyfile(filepath, best_filepath)
        logger.info(f'Saved best model to {best_filepath}')

def load_checkpoint(
    checkpoint_path: str,
    model: Optional[torch.nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
) -> Dict:
    """
    Load a checkpoint file and restore the state of the objects passed in.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Model to restore from the ``'model_state_dict'`` entry.
        optimizer: Optimizer to restore from the ``'optimizer_state_dict'`` entry.
        scheduler: LR scheduler to restore from the ``'scheduler_state_dict'`` entry.

    Returns:
        The loaded checkpoint content (mapped to CPU).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        Exception: Any error raised by a ``load_state_dict`` call is logged
            and re-raised unchanged.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint not found at {checkpoint_path}')

    logger.info(f'Loading checkpoint from {checkpoint_path}')
    # NOTE(security): torch.load relies on pickle-based deserialization; only
    # load checkpoint files that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # (object to restore, checkpoint key, label used in log messages)
    restorable = (
        (model, 'model_state_dict', 'model weights'),
        (optimizer, 'optimizer_state_dict', 'optimizer state'),
        (scheduler, 'scheduler_state_dict', 'scheduler state'),
    )
    for target, key, label in restorable:
        if target is None:
            continue
        if key not in checkpoint:
            # The caller asked for this object to be restored but the file has
            # nothing for it; say so instead of skipping silently.
            logger.warning(f"Checkpoint has no '{key}' entry; {label} not restored")
            continue
        try:
            target.load_state_dict(checkpoint[key])
            logger.info(f'Successfully loaded {label}')
        except Exception as e:
            logger.error(f'Error loading {label}: {e}')
            raise

    return checkpoint

def get_last_checkpoint(output_dir: str) -> Optional[str]:
    """
    Return the path of the most recently modified ``.pth`` file in a directory.

    Every regular file whose name ends in ``.pth`` is a candidate, whatever
    its base name.

    Args:
        output_dir: Directory to search.

    Returns:
        Path of the newest ``.pth`` file, or None if ``output_dir`` is not a
        directory or contains no such file.
    """
    # isdir (rather than exists) so that a regular file passed by mistake
    # yields None instead of a NotADirectoryError from os.listdir.
    if not os.path.isdir(output_dir):
        return None

    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.pth')
    ]
    # Drop sub-directories (or other non-files) that merely carry the suffix.
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        return None

    # Sort by modification time; the newest file ends up last.
    candidates.sort(key=os.path.getmtime)
    return candidates[-1]

def load_pretrained_weights(
    model: torch.nn.Module,
    pretrained_path: str,
    strict: bool = True
) -> None:
    """
    Load pretrained weights from a file into ``model``.

    The file may hold either a bare state dict or a checkpoint that stores
    the state dict under the ``'model_state_dict'`` key.

    Args:
        model: Model that receives the weights.
        pretrained_path: Path of the pretrained weights file.
        strict: Passed through to ``model.load_state_dict``.

    Raises:
        FileNotFoundError: If ``pretrained_path`` does not exist.
    """
    path_is_present = os.path.exists(pretrained_path)
    if not path_is_present:
        raise FileNotFoundError(f'Pretrained weights not found at {pretrained_path}')

    logger.info(f'Loading pretrained weights from {pretrained_path}')
    loaded = torch.load(pretrained_path, map_location='cpu')

    # Unwrap a full checkpoint down to the model weights when necessary.
    weights = loaded['model_state_dict'] if 'model_state_dict' in loaded else loaded

    try:
        model.load_state_dict(weights, strict=strict)
        logger.info('Successfully loaded pretrained weights')
    except Exception as err:
        # Log for context, then let the original error propagate.
        logger.error(f'Error loading pretrained weights: {err}')
        raise

def save_training_state(
    state: Dict,
    output_dir: str,
    filename: str = 'training_state.pth'
) -> None:
    """
    Save the training state (intended to exclude model weights).

    Args:
        state: Dictionary holding the training state.
        output_dir: Directory the file is written to (created if missing).
        filename: Name of the file inside ``output_dir``.
    """
    # exist_ok=True removes the check-then-create race: another process may
    # create the directory between an existence check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    filepath = os.path.join(output_dir, filename)
    torch.save(state, filepath)
    logger.info(f'Saved training state to {filepath}')

def cleanup_checkpoints(output_dir: str, keep_last_n: int = 5) -> None:
    """
    Delete old ``.pth`` files, keeping only the ``keep_last_n`` newest ones.

    Files are ranked by modification time. ``model_best.pth`` is never
    deleted and does not count towards ``keep_last_n``.

    Args:
        output_dir: Directory to clean; nothing happens if it is not a directory.
        keep_last_n: Number of most recent checkpoints to keep (0 removes all
            of them).

    Raises:
        ValueError: If ``keep_last_n`` is negative.
    """
    if keep_last_n < 0:
        raise ValueError(f'keep_last_n must be non-negative, got {keep_last_n}')
    if not os.path.isdir(output_dir):
        return

    # Only regular files are candidates; the best-model copy is excluded so
    # that an old-but-best model is not removed just because of its age.
    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.pth') and name != 'model_best.pth'
    ]
    candidates = [path for path in candidates if os.path.isfile(path)]

    # An explicit count is used instead of the slice [:-keep_last_n], which
    # is empty for keep_last_n == 0 and would therefore delete nothing.
    surplus = len(candidates) - keep_last_n
    if surplus <= 0:
        return

    # Oldest first, so the first ``surplus`` entries are the ones to drop.
    candidates.sort(key=os.path.getmtime)
    for filepath in candidates[:surplus]:
        os.remove(filepath)
        logger.info(f'Removed old checkpoint: {filepath}')
81 changes: 51 additions & 30 deletions tools/test_cls.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,78 @@
import os
from tqdm import tqdm
import time
import os # 导入os库,用于文件和目录操作
from tqdm import tqdm # 导入tqdm库,用于显示进度条
import time # 导入time库,用于时间相关的操作

import torch
from torch.utils.data import DataLoader
import torch # 导入PyTorch库,用于深度学习模型的定义和训练
from torch.utils.data import DataLoader # 从PyTorch中导入DataLoader,用于加载数据

from FMF.models import build_model
from FMF.datasets import build_dataset
from FMF.models import build_model # 从FMF.models模块导入build_model函数,用于构建模型
from FMF.datasets import build_dataset # 从FMF.datasets模块导入build_dataset函数,用于构建数据集

from FMF.utils.parser import parse_args, load_config
from FMF.utils.others import get_time
from FMF.utils.metrics import ClassificationMetric
from FMF.utils.parser import parse_args, load_config # 从FMF.utils.parser模块导入parse_args和load_config函数,用于解析命令行参数和加载配置
from FMF.utils.others import get_time # 从FMF.utils.others模块导入get_time函数,用于获取当前时间
from FMF.utils.metrics import ClassificationMetric # 从FMF.utils.metrics模块导入ClassificationMetric类,用于计算分类指标


def main():
"""
Main function: run the test procedure.

Parses the command-line arguments, loads the configuration, builds the
test dataset/loader and the model, loads the weights stored at
cfg.MODEL.PRETRAINED, runs the model over the test set without gradients
and prints ACC, F1, FDR, MDR and FPS.
"""
# NOTE(review): this span looks like a rendered diff in which removed and
# added lines are interleaved (several statements appear twice) - confirm
# against the real file before relying on statement order.
# Parse command-line arguments
args = parse_args()
# Load the configuration from the parsed arguments
cfg = load_config(args)
# Assert that the pretrained model file exists
assert os.path.exists(cfg.MODEL.PRETRAINED), f'{cfg.MODEL.PRETRAINED}不存在!'

# Select the device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the test set
print('[{}] Loading test set...'.format(get_time()))
val_set = build_dataset(name=cfg.TEST.DATASET, cfg=cfg, split='test')
val_loader = DataLoader(val_set, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False,
num_workers=cfg.DATA_LOADER.NUM_WORKERS, pin_memory=cfg.DATA_LOADER.PIN_MEMORY)

# Build the model
print('[{}] Constructing model...'.format(get_time()))

model = build_model(cfg)
model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED))
model.to(device)
# Load the pretrained model weights

# model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED))
checkpoint = torch.load(cfg.MODEL.PRETRAINED)
if 'model_state_dict' in checkpoint:
model.load_state_dict(checkpoint['model_state_dict'])
else:
model.load_state_dict(checkpoint)

test_metric = ClassificationMetric(numClass=cfg.MODEL.NUM_CLASSES)
# Move the model to the selected device (GPU or CPU)
model.to(device)
numClass = cfg.MODEL.NUM_CLASSES
# Initialize the classification metric helper
test_metric = ClassificationMetric(numClass)

# Put the model in evaluation mode
model.eval()
# Start testing
print('[{}] Testing...'.format(get_time()))
with torch.no_grad():
start_time = time.time()
for i, xy, in enumerate(tqdm(val_loader, ncols=100)):
y = xy[-1]
x = [ipt.to(device) for ipt in xy[:-1]]
y_hat = model(*x)
with torch.no_grad(): # Disable gradient computation to save memory and compute
start_time = time.time() # Record the test start time
for i, xy in enumerate(tqdm(val_loader, ncols=100)): # Iterate over the test set with a progress bar
y = xy[-1] # Labels are the last element of the batch
x = [ipt.to(device) for ipt in xy[:-1]] # Move the remaining batch elements (inputs) to the device
y_hat = model(*x) # Forward pass to get the model predictions

# Add the predictions and the ground-truth labels to the metric helper
test_metric.addBatch(torch.argmax(y_hat.detach().cpu(), dim=1).numpy(), y.cpu().numpy())
end_time = time.time()
fps = len(val_set) / (end_time - start_time)
end_time = time.time() # Record the test end time
fps = len(val_set) / (end_time - start_time) # Samples processed per second (FPS)

# Print the test results
print('[{}] Tests complete!'.format(get_time()))
print('ACC: {}'.format(test_metric.Accuracy()))
print('F1 : {}'.format(test_metric.F1Score()))
print('FDR: {}'.format(test_metric.FalsePositiveRate()))
print('MDR: {}'.format(test_metric.FalseNegativeRate()))
print('FPS: {}'.format(fps))
print('ACC: {}'.format(test_metric.Accuracy())) # Print the accuracy
print('F1 : {}'.format(test_metric.F1Score())) # Print the F1 score
print('FDR: {}'.format(test_metric.FalsePositiveRate())) # Print the false positive rate
print('MDR: {}'.format(test_metric.FalseNegativeRate())) # Print the false negative rate
print('FPS: {}'.format(fps)) # Print the samples-per-second figure


if __name__ == '__main__':
main()
main() # 执行主函数
Loading