-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargs.py
More file actions
99 lines (89 loc) · 8.59 KB
/
Copy pathargs.py
File metadata and controls
99 lines (89 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import os
def parse_args():
parser = argparse.ArgumentParser(description="GeospatialFM Training Arguments")
# Dataset arguments
parser.add_argument("--dataset_name", type=str, choices=["ssl4eo", "enmap"], help="Dataset to run")
parser.add_argument("--data_dir", type=str, required=True, help="Path to the SSL4EO dataset")
parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of subprocesses to use for data loading")
parser.add_argument("--dataloader_pin_memory", action="store_true", help="Whether to pin memory for data loading")
parser.add_argument("--use_8bit", action="store_true", help="Whether to use 8-bit data loading")
parser.add_argument("--crop_size", type=int, default=None, help="Crop size for training")
parser.add_argument("--scale", type=float, default=None, help="Scale for training")
parser.add_argument("--gen_task", type=str, default=None, choices=["id", "ood_a", "ood_full", "ood_complement"], help="Generalization task")
parser.add_argument("--enmap_subset", type=int, default=None, choices=[120, 60, 10])
# Model arguments
parser.add_argument("--patch_size", type=int, default=16, help="Size of patches for hyperspectral patch embedding")
parser.add_argument("--in_channels", type=int, default=None, help="Number of hyperspectral input channels. If not set, infer from dataset metadata when possible")
parser.add_argument("--embed_dim", type=int, default=768, help="Embedding dimension")
parser.add_argument("--channel_embed_dims_per_head", type=int, default=4, help="Number of channel embedding dimensions per head")
parser.add_argument("--depth", type=int, default=12, help="Number of transformer layers")
parser.add_argument("--num_heads", type=int, default=12, help="Number of attention heads")
parser.add_argument("--mlp_ratio", type=float, default=4.0, help="MLP ratio")
parser.add_argument("--qkv_bias", type=bool, default=True, help="Use bias in qkv projections")
parser.add_argument("--qk_norm", type=bool, default=False, help="Use qk normalization")
parser.add_argument("--drop_path_rate", type=float, default=0.0, help="Drop path rate")
parser.add_argument("--drop_path_uniform", type=bool, default=False, help="Use uniform drop path")
parser.add_argument("--init_values", type=float, default=None, help="Init values for LayerScale")
parser.add_argument("--attn_drop", type=float, default=0.0, help="Attention dropout rate")
parser.add_argument("--proj_drop", type=float, default=0.0, help="Projection dropout rate")
parser.add_argument("--mask_ratio", type=float, default=0.75, help="Mask ratio for MAE")
parser.add_argument("--channel_mask_ratio", type=float, default=0.5, help="Channel mask ratio for MAE")
parser.add_argument("--rank", type=int, default=1, help="Rank for LowRankAttention")
# AdditionalModel arguments
parser.add_argument("--return_dict", type=bool, default=False, help="Return a dictionary instead of a tuple")
parser.add_argument("--norm_pix_loss", action="store_true", help="Whether to normalize pixel loss")
parser.add_argument("--use_perception_field_mask", action="store_true", help="Use perception field mask")
parser.add_argument("--attention_radius", type=int, default=640, help="Attention radius for perception field mask")
parser.add_argument("--loss_type", type=str, default="mse", help="Type of loss function")
parser.add_argument("--random_crop", action="store_true", help="Use random crop")
parser.add_argument("--use_rope_embed", action="store_true", help="Use RoPe positional embedding")
parser.add_argument("--rope_embed_base", type=float, default=100.0, help="Base value for RoPe positional embedding")
parser.add_argument("--channel_dropout", type=float, nargs='+', default=None, help="Channel dropout rate for training")
# Task Model arguments
parser.add_argument("--task_type", type=str, choices=["classification", "multilabel", "segmentation", "Pretraining"], default="Pretraining", help="Task type")
parser.add_argument("--num_experts", type=int, default=None, help="Number of experts, -1 for all channels, None for CLS token only")
parser.add_argument("--topk", type=int, default=3, help="Top-k for MoE, -1 for all channels,")
parser.add_argument("--use_moe", action="store_true", help="Use MoE")
# Decoder arguments
parser.add_argument("--decoder_embed_dim", type=int, default=512, help="Embedding dimension for decoder")
parser.add_argument("--decoder_depth", type=int, default=8, help="Number of transformer layers for decoder")
parser.add_argument("--decoder_num_heads", type=int, default=16, help="Number of attention heads for decoder")
parser.add_argument("--decoder_out_chans", type=int, default=1, help="Number of output channels for decoder")
parser.add_argument("--decoder_out_dims", type=int, default=1, help="Number of output dimensions for decoder")
parser.add_argument("--decoder_channel_embed_dims_per_head", type=int, default=4, help="Number of channel embedding dimensions per head for decoder")
# Trainer arguments
parser.add_argument("--run_name", type=str, required=True, help="Name of the run")
parser.add_argument("--per_device_train_batch_size", type=int, default=32, help="Batch size for training")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam optimizer beta1")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam optimizer beta2")
parser.add_argument("--weight_decay", type=float, default=0.05, help="Optimizer weight decay")
parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="Adam optimizer epsilon")
parser.add_argument("--max_steps", type=int, default=-1, help="Total number of training steps")
parser.add_argument("--num_train_epochs", type=int, default=100, help="Total number of training epochs")
parser.add_argument("--lr_scheduler_type", type=str, default="cosine", help="Type of learning rate scheduler")
parser.add_argument("--warmup_steps", type=int, default=0, help="Number of warmup steps for learning rate scheduler")
parser.add_argument("--warmup_ratio", type=float, default=0.2, help="Warmup ratio for learning rate scheduler")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass")
parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether to use gradient checkpointing to save memory at the expense of slower backward pass")
parser.add_argument("--max_grad_norm", type=float, default=None, help="Max gradient norm for gradient clipping")
parser.add_argument("--model_name", type=str, default="lessvit", choices=["lessvit", "specvit", "dinov3", "dofa", "spatsigma"], help="Model to train")
# Logging and saving arguments
parser.add_argument("--output_dir", type=str, default="output", help="Directory to save model checkpoints and logs")
parser.add_argument("--logging_dir", type=str, default="logs", help="Directory to save logs")
parser.add_argument("--report_to", type=str, default="wandb", help="Where to report results to (tensorboard, wandb, etc.)")
parser.add_argument("--save_strategy", type=str, default="epoch", help="Save strategy")
parser.add_argument("--save_steps", type=float, default=500, help="Save checkpoint every X updates steps")
parser.add_argument("--save_total_limit", type=int, default=None, help="If set, deletes the older checkpoints in output_dir")
parser.add_argument("--wandb_dir", type=str, default="wandb", help="Directory to save wandb logs")
# Other arguments
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training")
parser.add_argument("--mixed_precision", type=str, default=None, choices=[None, "fp16", "bf16"], help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU")
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="If the training should continue from a checkpoint folder")
parser.add_argument("--modal_mode", type=str, default=None, choices=["multi", "optical", "radar", None, "random"], help="Mode of the model, if None, the model is trained in three modes simultaneously")
# Append run name to directories
args = parser.parse_args()
args.output_dir = os.path.join(args.output_dir, args.run_name)
args.logging_dir = os.path.join(args.logging_dir, args.run_name)
return args