From fbaa8ea8014ac2508a1ead2943ca2d67715b2b00 Mon Sep 17 00:00:00 2001 From: ZhaoTianyou Date: Thu, 6 Mar 2025 16:42:54 +0800 Subject: [PATCH 1/3] Add training code --- tools/train_cls.py | 305 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 tools/train_cls.py diff --git a/tools/train_cls.py b/tools/train_cls.py new file mode 100644 index 0000000..59718d6 --- /dev/null +++ b/tools/train_cls.py @@ -0,0 +1,305 @@ +import os +import torch +import time # 导入time库,用于时间相关的操作 +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +import logging +from datetime import datetime +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter # 导入TensorBoard +from FMF.utils.parser import parse_args, load_config # 从FMF.utils.parser模块导入parse_args和load_config函数,用于解析命令行参数和加载配置 +from FMF.models import build_model +from FMF.datasets.build import build_dataset +from FMF.utils.checkpoint import save_checkpoint, load_checkpoint +from FMF.utils.others import get_time +from FMF.utils.metrics import ClassificationMetric # 从FMF.utils.metrics模块导入ClassificationMetric类,用于计算分类指标 + + +def setup_logger(output_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(os.path.join(output_dir, 'train.log')) + # logging.StreamHandler() + ] + ) + return logging.getLogger(__name__) + +class Trainer: + def __init__(self, cfg, output_dir): + self.cfg = cfg + self.output_dir = output_dir + self.logger = setup_logger(output_dir) + + # 创建TensorBoard的SummaryWriter + self.writer = SummaryWriter('/root/tf-logs') + + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # 构建数据集和数据加载器 + print('[{}] Loading test and train set...'.format(get_time())) + self.train_dataset = build_dataset(name=cfg.TRAIN.DATASET,cfg=cfg, split='train') + self.val_dataset = build_dataset(name=cfg.TEST.DATASET, cfg=cfg, split='test') + + self.train_loader = DataLoader( + self.train_dataset, + batch_size=cfg.TRAIN.BATCH_SIZE, + shuffle=True, + num_workers=cfg.DATA_LOADER.NUM_WORKERS, + pin_memory=cfg.DATA_LOADER.PIN_MEMORY + ) + self.val_loader = DataLoader( + self.val_dataset, + batch_size=cfg.TEST.BATCH_SIZE, + shuffle=False, + num_workers=cfg.DATA_LOADER.NUM_WORKERS, + pin_memory=cfg.DATA_LOADER.PIN_MEMORY + ) + # 构建模型 + + print('[{}] Constructing model...'.format(get_time())) + self.model = build_model(cfg) + + # 输出模型结构 + print(self.model) + + # 随机初始化模型参数 + print('[{}] Initializing model parameters...'.format(get_time())) + for m in self.model.modules(): + if isinstance(m, nn.Linear): + # 使用xavier初始化线性层 + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + # 层归一化层初始化 + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.MultiheadAttention): + # 注意力层的初始化 + nn.init.xavier_uniform_(m.in_proj_weight) + nn.init.xavier_uniform_(m.out_proj.weight) + if m.in_proj_bias is not None: + nn.init.constant_(m.in_proj_bias, 0.0) + if m.out_proj.bias is not None: + nn.init.constant_(m.out_proj.bias, 0.0) + elif isinstance(m, nn.Conv2d): + # 如果有卷积层,使用kaiming初始化 + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + + # 将模型移动到指定设备(GPU或CPU) + self.model = self.model.to(self.device) + + self.numClass = cfg.MODEL.NUM_CLASSES + # 初始化分类指标计算工具 + self.test_metric = ClassificationMetric(self.numClass) + self.train_metric = ClassificationMetric(self.numClass) + + # 定义损失函数和优化器 + self.criterion = nn.CrossEntropyLoss() + self.criterion.to(self.device) + + self.optimizer = optim.Adam( + self.model.parameters(), + lr=cfg.SOLVER.BASE_LR + ) + + # 学习率调度器 + self.scheduler = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, + T_max=cfg.SOLVER.MAX_EPOCH, + eta_min=cfg.SOLVER.COSINE_END_LR + ) + + self.start_epoch = 0 + self.best_acc = 0.0 + + + def train_epoch(self, epoch): + self.model.train() + self.train_metric.reset() + pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}/{self.cfg.SOLVER.MAX_EPOCH}', ncols=100) + total_loss = 0.0 # 添加损失累计变量 + for batch_idx, xy in enumerate(pbar): + y = xy[-1].to(self.device) # 获取标签 + x = [ipt.to(self.device) for ipt in xy[:-1]] + + # 前向传播 + outputs = self.model(*x) + loss = self.criterion(outputs, y) + # 反向传播和优化 + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + # 累计损失 + total_loss += loss.item() + # 计算准确率 + self.train_metric.addBatch(torch.argmax(outputs.detach().cpu(), dim=1).numpy(), y.cpu().numpy()) + acc = self.train_metric.Accuracy() + f1 = self.train_metric.F1Score() + recall = self.train_metric.Recall() + self.avg_loss = total_loss / (batch_idx + 1) + + # 记录训练指标到TensorBoard + global_step = epoch * len(self.train_loader) + batch_idx + self.writer.add_scalar('Loss/train', self.avg_loss, global_step) + self.writer.add_scalar('Accuracy/train', acc, global_step) + self.writer.add_scalar('F1/train', f1, global_step) + self.writer.add_scalar('Recall/train', recall, global_step) + self.writer.add_scalar('Learning_rate', self.optimizer.param_groups[0]['lr'], global_step) + + # 更新进度条 + if batch_idx % 200 == 0: + pbar.set_postfix({ + 'loss': f'{self.avg_loss:.4f}' + }) + # 记录日志 + if batch_idx % self.cfg.PRINT_FREQ == 0: + self.logger.info( + f'Epoch: [{epoch}][{batch_idx}/{len(self.train_loader)}] ' + f'loss: {self.avg_loss:.4f} ' + f'acc: {acc:.4f} ' + f'F1: {f1:.4f} ' + f'recall: {recall:.4f}' + ) + + print('[{}] Train test complete! loss: {: .4f}, acc : {: .4f}, F1: {: .4f}, recall: {: .4f}'.format(get_time(), + self.avg_loss, + self.train_metric.Accuracy(), + self.train_metric.F1Score(), + self.train_metric.Recall() + )) + self.acc = self.train_metric.Accuracy() + + def validate(self): + self.model.eval() + self.test_metric.reset() + + # 开始测试 + print('[{}] Testing...'.format(get_time())) + with torch.no_grad(): # 禁用梯度计算,以节省内存和计算资源 + start_time = time.time() # 记录测试开始时间 + for i, xy in enumerate(tqdm(self.val_loader, ncols=100)): # 遍历测试集,并显示进度条 + y = xy[-1] # 获取标签 + x = [ipt.to(self.device) for ipt in xy[:-1]] # 将输入数据移动到设备 + y_hat = self.model(*x) # 前向传播,获取模型的预测结果 + + # 将预测结果和真实标签添加到分类指标计算工具中 + self.test_metric.addBatch(torch.argmax(y_hat.detach().cpu(), dim=1).numpy(), y.cpu().numpy()) + end_time = time.time() # 记录测试结束时间 + fps = len(self.val_loader) / (end_time - start_time) # 计算每秒帧数(FPS) + + # 记录验证指标到TensorBoard + val_acc = self.test_metric.Accuracy() + val_f1 = self.test_metric.F1Score() + val_fdr = self.test_metric.FalsePositiveRate() + val_mdr = self.test_metric.FalseNegativeRate() + + self.writer.add_scalar('Accuracy/val', val_acc, self.current_epoch) + self.writer.add_scalar('F1/val', val_f1, self.current_epoch) + self.writer.add_scalar('FDR/val', val_fdr, self.current_epoch) + self.writer.add_scalar('MDR/val', val_mdr, self.current_epoch) + self.writer.add_scalar('FPS/val', fps, self.current_epoch) + + print('[{}] Tests complete! ACC: {}, F1 : {}, FDR: {}, MDR: {}, FPS: {}'.format(get_time(), + val_acc, + val_f1, + val_fdr, + val_mdr, + fps)) + + return val_acc, val_f1 + + def train(self): + for epoch in range(self.start_epoch, self.cfg.SOLVER.MAX_EPOCH): + self.current_epoch = epoch # 添加当前epoch的记录 + # 训练一个epoch + self.train_epoch(epoch) + # 验证 + val_acc, val_f1 = self.validate() + # 更新学习率 + self.scheduler.step() + # 记录日志 + self.logger.info( + f'Epoch: {epoch}, ' + f'Train Loss: {self.avg_loss:.4f}, ' + f'Train Acc: {self.acc:.4f}, ' + f'Val F1: {val_f1:.4f}, ' + f'Val Acc: {val_acc:.4f}' + ) + + # 保存最佳模型 + if val_acc > self.best_acc: + self.best_acc = val_acc + save_checkpoint( + { + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'best_acc': self.best_acc, + }, + is_best=True, + output_dir=self.output_dir + ) + + # 定期保存检查点 + if (epoch + 1) % self.cfg.TRAIN.CHECKPOINT_PERIOD == 0: + save_checkpoint( + { + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'best_acc': self.best_acc, + }, + is_best=False, + output_dir=self.output_dir + ) + + # 关闭TensorBoard的SummaryWriter + self.writer.close() + + def resume_from_checkpoint(self, checkpoint_path): + checkpoint = load_checkpoint(checkpoint_path) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + self.start_epoch = checkpoint['epoch'] + 1 + self.best_acc = checkpoint['best_acc'] + self.logger.info(f'Resumed from epoch {self.start_epoch} with best acc {self.best_acc}') + +def main(): + # 解析命令行参数 + args = parse_args() + # 根据命令行参数加载配置 + cfg = load_config(args) + + # 设置输出目录 + output_dir = os.path.join( + cfg.OUTPUT_DIR, + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + ) + + os.makedirs(output_dir, exist_ok=True) + + # 创建训练器 + trainer = Trainer(cfg, output_dir) + + # 如果指定了检查点,则从检查点恢复 + if args.resume: + trainer.resume_from_checkpoint(args.resume) + + # 开始训练 + trainer.train() + +if __name__ == '__main__': + main() \ No newline at end of file From e9e2c7195c2a5e1cf9a26fa2704122aa60ee52b3 Mon Sep 17 00:00:00 2001 From: ZhaoTianyou Date: Thu, 6 Mar 2025 16:50:38 +0800 Subject: [PATCH 2/3] Enable power-off reconnect --- FMF/utils/checkpoint.py | 183 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 FMF/utils/checkpoint.py diff --git a/FMF/utils/checkpoint.py b/FMF/utils/checkpoint.py new file mode 100644 index 0000000..ddcbfb8 --- /dev/null +++ b/FMF/utils/checkpoint.py @@ -0,0 +1,183 @@ +import os +import torch +import logging +import shutil +from typing import Dict, Optional, Union + +logger = logging.getLogger(__name__) + +def save_checkpoint( + state: Dict, + is_best: bool, + output_dir: str, + filename: str = 'checkpoint.pth' +) -> None: + """ + 保存模型检查点 + + Args: + state: 包含模型状态、优化器状态等的字典 + is_best: 是否为最佳模型 + output_dir: 输出目录 + filename: 检查点文件名 + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # 保存检查点 + filepath = os.path.join(output_dir, filename) + torch.save(state, filepath) + logger.info(f'Saved checkpoint to {filepath}') + + # 如果是最佳模型,保存一个副本 + if is_best: + best_filepath = os.path.join(output_dir, 'model_best.pth') + shutil.copyfile(filepath, best_filepath) + logger.info(f'Saved best model to {best_filepath}') + +def load_checkpoint( + checkpoint_path: str, + model: Optional[torch.nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None +) -> Dict: + """ + 加载模型检查点 + + Args: + checkpoint_path: 检查点文件路径 + model: 要加载权重的模型 + optimizer: 要加载状态的优化器 + scheduler: 要加载状态的学习率调度器 + + Returns: + 包含检查点内容的字典 + """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f'Checkpoint not found at {checkpoint_path}') + + logger.info(f'Loading checkpoint from {checkpoint_path}') + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # 加载模型权重 + if model is not None and 'model_state_dict' in checkpoint: + try: + model.load_state_dict(checkpoint['model_state_dict']) + logger.info('Successfully loaded model weights') + except Exception as e: + logger.error(f'Error loading model weights: {e}') + raise + + # 加载优化器状态 + if optimizer is not None and 'optimizer_state_dict' in checkpoint: + try: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + logger.info('Successfully loaded optimizer state') + except Exception as e: + logger.error(f'Error loading optimizer state: {e}') + raise + + # 加载学习率调度器状态 + if scheduler is not None and 'scheduler_state_dict' in checkpoint: + try: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + logger.info('Successfully loaded scheduler state') + except Exception as e: + logger.error(f'Error loading scheduler state: {e}') + raise + + return checkpoint + +def get_last_checkpoint(output_dir: str) -> Optional[str]: + """ + 获取最新的检查点文件路径 + + Args: + output_dir: 输出目录 + + Returns: + 最新的检查点文件路径,如果没有找到则返回None + """ + if not os.path.exists(output_dir): + return None + + checkpoints = [f for f in os.listdir(output_dir) if f.endswith('.pth')] + if not checkpoints: + return None + + # 按修改时间排序,返回最新的检查点 + checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x))) + return os.path.join(output_dir, checkpoints[-1]) + +def load_pretrained_weights( + model: torch.nn.Module, + pretrained_path: str, + strict: bool = True +) -> None: + """ + 加载预训练权重 + + Args: + model: 要加载权重的模型 + pretrained_path: 预训练权重文件路径 + strict: 是否严格匹配权重 + """ + if not os.path.exists(pretrained_path): + raise FileNotFoundError(f'Pretrained weights not found at {pretrained_path}') + + logger.info(f'Loading pretrained weights from {pretrained_path}') + state_dict = torch.load(pretrained_path, map_location='cpu') + + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + + try: + model.load_state_dict(state_dict, strict=strict) + logger.info('Successfully loaded pretrained weights') + except Exception as e: + logger.error(f'Error loading pretrained weights: {e}') + raise + +def save_training_state( + state: Dict, + output_dir: str, + filename: str = 'training_state.pth' +) -> None: + """ + 保存训练状态(不包含模型权重) + + Args: + state: 包含训练状态的字典 + output_dir: 输出目录 + filename: 文件名 + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + filepath = os.path.join(output_dir, filename) + torch.save(state, filepath) + logger.info(f'Saved training state to {filepath}') + +def cleanup_checkpoints(output_dir: str, keep_last_n: int = 5) -> None: + """ + 清理旧的检查点文件,只保留最新的N个 + + Args: + output_dir: 输出目录 + keep_last_n: 保留最新的检查点数量 + """ + if not os.path.exists(output_dir): + return + + checkpoints = [f for f in os.listdir(output_dir) if f.endswith('.pth')] + if len(checkpoints) <= keep_last_n: + return + + # 按修改时间排序 + checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x))) + + # 删除旧的检查点 + for checkpoint in checkpoints[:-keep_last_n]: + filepath = os.path.join(output_dir, checkpoint) + os.remove(filepath) + logger.info(f'Removed old checkpoint: {filepath}') \ No newline at end of file From ce7045930233c1d8d3046507ce7160a52c67f50b Mon Sep 17 00:00:00 2001 From: ZhaoTianyou Date: Thu, 6 Mar 2025 16:58:12 +0800 Subject: [PATCH 3/3] Modify the code for the model loading section of the test dataset. --- tools/test_cls.py | 81 +++++++++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/tools/test_cls.py b/tools/test_cls.py index e73a17d..46d1f7b 100644 --- a/tools/test_cls.py +++ b/tools/test_cls.py @@ -1,57 +1,78 @@ -import os -from tqdm import tqdm -import time +import os # 导入os库,用于文件和目录操作 +from tqdm import tqdm # 导入tqdm库,用于显示进度条 +import time # 导入time库,用于时间相关的操作 -import torch -from torch.utils.data import DataLoader +import torch # 导入PyTorch库,用于深度学习模型的定义和训练 +from torch.utils.data import DataLoader # 从PyTorch中导入DataLoader,用于加载数据 -from FMF.models import build_model -from FMF.datasets import build_dataset +from FMF.models import build_model # 从FMF.models模块导入build_model函数,用于构建模型 +from FMF.datasets import build_dataset # 从FMF.datasets模块导入build_dataset函数,用于构建数据集 -from FMF.utils.parser import parse_args, load_config -from FMF.utils.others import get_time -from FMF.utils.metrics import ClassificationMetric +from FMF.utils.parser import parse_args, load_config # 从FMF.utils.parser模块导入parse_args和load_config函数,用于解析命令行参数和加载配置 +from FMF.utils.others import get_time # 从FMF.utils.others模块导入get_time函数,用于获取当前时间 +from FMF.utils.metrics import ClassificationMetric # 从FMF.utils.metrics模块导入ClassificationMetric类,用于计算分类指标 def main(): + """ + 主函数,执行测试流程。 + """ + # 解析命令行参数 args = parse_args() + # 根据命令行参数加载配置 cfg = load_config(args) + # 检查预训练模型是否存在,如果不存在则抛出异常 assert os.path.exists(cfg.MODEL.PRETRAINED), f'{cfg.MODEL.PRETRAINED}不存在!' - + # 设置设备(GPU或CPU) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + # 加载测试集 print('[{}] Loading test set...'.format(get_time())) val_set = build_dataset(name=cfg.TEST.DATASET, cfg=cfg, split='test') val_loader = DataLoader(val_set, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=cfg.DATA_LOADER.NUM_WORKERS, pin_memory=cfg.DATA_LOADER.PIN_MEMORY) - + # 构建模型 print('[{}] Constructing model...'.format(get_time())) + model = build_model(cfg) - model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED)) - model.to(device) + # 加载预训练模型的权重 + + # model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED)) + checkpoint = torch.load(cfg.MODEL.PRETRAINED) + if 'model_state_dict' in checkpoint: + model.load_state_dict(checkpoint['model_state_dict']) + else: + model.load_state_dict(checkpoint) - test_metric = ClassificationMetric(numClass=cfg.MODEL.NUM_CLASSES) + # 将模型移动到指定设备(GPU或CPU) + model.to(device) + numClass = cfg.MODEL.NUM_CLASSES + # 初始化分类指标计算工具 + test_metric = ClassificationMetric(numClass) + # 将模型设置为评估模式 model.eval() + # 开始测试 print('[{}] Testing...'.format(get_time())) - with torch.no_grad(): - start_time = time.time() - for i, xy, in enumerate(tqdm(val_loader, ncols=100)): - y = xy[-1] - x = [ipt.to(device) for ipt in xy[:-1]] - y_hat = model(*x) + with torch.no_grad(): # 禁用梯度计算,以节省内存和计算资源 + start_time = time.time() # 记录测试开始时间 + for i, xy in enumerate(tqdm(val_loader, ncols=100)): # 遍历测试集,并显示进度条 + y = xy[-1] # 获取标签 + x = [ipt.to(device) for ipt in xy[:-1]] # 将输入数据移动到设备 + y_hat = model(*x) # 前向传播,获取模型的预测结果 + # 将预测结果和真实标签添加到分类指标计算工具中 test_metric.addBatch(torch.argmax(y_hat.detach().cpu(), dim=1).numpy(), y.cpu().numpy()) - end_time = time.time() - fps = len(val_set) / (end_time - start_time) + end_time = time.time() # 记录测试结束时间 + fps = len(val_set) / (end_time - start_time) # 计算每秒帧数(FPS) + # 打印测试结果 print('[{}] Tests complete!'.format(get_time())) - print('ACC: {}'.format(test_metric.Accuracy())) - print('F1 : {}'.format(test_metric.F1Score())) - print('FDR: {}'.format(test_metric.FalsePositiveRate())) - print('MDR: {}'.format(test_metric.FalseNegativeRate())) - print('FPS: {}'.format(fps)) + print('ACC: {}'.format(test_metric.Accuracy())) # 打印准确率 + print('F1 : {}'.format(test_metric.F1Score())) # 打印F1分数 + print('FDR: {}'.format(test_metric.FalsePositiveRate())) # 打印假阳性率 + print('MDR: {}'.format(test_metric.FalseNegativeRate())) # 打印假阴性率 + print('FPS: {}'.format(fps)) # 打印每秒帧数 if __name__ == '__main__': - main() + main() # 执行主函数