- Abstract:
  This study presents a new algorithm called TripletMAML which extends the Model-Agnostic Meta-Learning (MAML) algorithm from a metric-learning perspective. The same optimization procedure of MAML is adopted, but the neural network model is replaced with a triplet network which enables the utilization of metric-learning through the embeddings. In order to be able to incorporate the metric loss during meta-learning, we have developed a triplet-task generation scheme that creates tasks consisting of triplets for both
- Title:
  TripletMAML: A metric-based model-agnostic meta-learning algorithm for few-shot classification
- Authors:
  Ayla Gülcü, Zeki Kuş, İsmail Taha Samed Özkan, Osman Furkan Karakuş
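To make the triplet idea from the abstract concrete, here is a minimal, framework-free sketch of how (anchor, positive, negative) triplets can be drawn from labeled samples and scored with a margin-based loss. It is not the repository's task generator: the function names `sample_triplets` and `triplet_margin_loss`, the uniform sampling strategy, and the margin value are all assumptions made for illustration.

```python
import math
import random
from collections import defaultdict


def sample_triplets(labels, n_triplets, rng=None):
    """Draw (anchor, positive, negative) index triplets from a label list.

    Anchor and positive share a class; the negative comes from another class.
    Hypothetical helper, not the repository's triplet-task generation scheme.
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with at least two samples can supply an anchor/positive pair.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    if not eligible or len(by_class) < 2:
        raise ValueError("need >= 2 classes and one class with >= 2 samples")

    triplets = []
    for _ in range(n_triplets):
        pos_class = rng.choice(eligible)
        neg_class = rng.choice([c for c in by_class if c != pos_class])
        anchor, positive = rng.sample(by_class[pos_class], 2)
        negative = rng.choice(by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Hinge loss on Euclidean distances between embedding vectors."""
    d_ap = math.dist(anchor, positive)
    d_an = math.dist(anchor, negative)
    return max(0.0, d_ap - d_an + margin)
```

The loss is zero once the negative is farther from the anchor than the positive by at least `margin`, which is what pushes same-class embeddings together during training.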
- python == 3.9.7
- learn2learn == 0.1.7
- numpy == 1.20.3
- pytorch == 1.10.2 (build py3.9_cuda11.3_cudnn8.2.0_0)
- scikit-learn == 1.2.0
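A quick way to compare an existing environment against these pins is sketched below using only the standard library. The distribution names are an assumption: the list above says `pytorch` (the conda package name), while the sketch queries `torch`, which is the name pip reports. `check_versions` is a hypothetical helper, not part of this repository.

```python
from importlib import metadata

# Pins copied from the requirements list; distribution names are assumed
# (PyTorch is reported as "torch" by pip-style metadata).
REQUIRED = {
    "learn2learn": "0.1.7",
    "numpy": "1.20.3",
    "torch": "1.10.2",
    "scikit-learn": "1.2.0",
}


def check_versions(required, get_version=metadata.version):
    """Return {package: (expected, found)} for each mismatch or missing package."""
    problems = {}
    for name, expected in required.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Ignore local build tags such as "+cu113" when comparing.
        if found is None or found.split("+")[0] != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in check_versions(REQUIRED).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means every listed package is installed at the pinned version.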
| Dataset | Download Link | Manual Download Needed |
|---|---|---|
| Omniglot | download | NO |
| MiniImageNet | download | NO |
| CUB-200-2011 | download | YES |
| CIFAR Few-Shot | download | YES |
- for Omniglot
  The TripletOmniglot.py files handle downloading automatically if download mode is set to True.
- for MiniImageNet
  ./data/MiniImageNet/MiniImageNet_downloader.py downloads the test-train-validation files. No separate dataset download is needed.
- for CUB-200-2011
  The CUB-200-2011 files must be downloaded from the given link and placed into the corresponding folder, ./data/CUB/. Then run the generator of your preference to create the test-train-validation files.
- for CIFAR Few-Shot
  The CIFAR Few-Shot files must be downloaded from the given link and placed into the corresponding folder, ./data/CIFARFS100/. Run the processor first, then use the generator to create the test-train-validation files.
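Before running a generator, it can help to confirm that the manually downloaded files are where the steps above expect them. The sketch below checks the two folders named in this section. `missing_datasets` is a hypothetical helper, and it assumes that any entry not ending in `.py` counts as dataset content, since the expected file names are not listed here.

```python
from pathlib import Path

# Folders named in this README for the two manually downloaded datasets.
MANUAL_DATASET_DIRS = {
    "CUB-200-2011": Path("data/CUB"),
    "CIFAR Few-Shot": Path("data/CIFARFS100"),
}


def missing_datasets(root, dataset_dirs=MANUAL_DATASET_DIRS):
    """Return names of datasets whose folder is absent or holds only .py scripts."""
    root = Path(root)
    missing = []
    for name, rel in dataset_dirs.items():
        folder = root / rel
        # Assumption: anything that is not a Python script is dataset content.
        has_data = folder.is_dir() and any(
            entry.suffix != ".py" for entry in folder.iterdir()
        )
        if not has_data:
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in missing_datasets("."):
        print(f"{name}: no downloaded files found")
```

Run it from the repository root; an empty list means both folders contain something besides the generator and processor scripts.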
- for classification
  See ./TripletMAML/maml_triplet_train_test_val.py
- for retrieval
  See ./TripletMAML/maml_triplet_test_retrieval.py
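This README does not say which metric the retrieval script reports. As a general illustration of how retrieval over learned embeddings is often scored, the sketch below computes Recall@K with Euclidean distance. `recall_at_k` is a hypothetical function written for this example and may differ from what maml_triplet_test_retrieval.py actually does.

```python
import math


def recall_at_k(embeddings, labels, k=1):
    """Fraction of queries with a same-class item among their k nearest neighbours.

    Each item is used as a query against all other items.
    """
    hits = 0
    for i, query in enumerate(embeddings):
        # Rank every other item by Euclidean distance to the query.
        others = sorted(
            (j for j in range(len(embeddings)) if j != i),
            key=lambda j: math.dist(query, embeddings[j]),
        )
        if any(labels[j] == labels[i] for j in others[:k]):
            hits += 1
    return hits / len(embeddings)
```

With embeddings that cluster by class, Recall@1 approaches 1.0; the value drops as same-class items drift apart in embedding space.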
|—— .gitignore
|—— data
| |—— Generators
| |—— CIFARFS100
| |—— CIFARFS100_generator.py
| |—— CIFARFS100_processor.py
| |—— CUB
| |—— CUB_BB_NoResize_generator.py
| |—— CUB_BB_Resize_generator.py
| |—— CUB_NoBB_generator.py
| |—— Flowers
| |—— Flowers_generator.py
| |—— MiniImageNet
| |—— MiniImageNet_downloader.py
|—— HPO
| |—— backbone.py
| |—— de.py
| |—— losses.py
| |—— model.py
| |—— rs.py
| |—— train.py
| |—— Triplets
| |—— TripletCUB.py
| |—— TripletFlowers.py
| |—— TripletFSCIFAR100.py
| |—— TripletMiniImageNet.py
| |—— TripletOmniglot.py
| |—— __init__.py
|—— TripletMAML
| |—— backbone.py
| |—— losses.py
| |—— maml_triplet_test_retrieval.py
| |—— maml_triplet_train_test_val.py
| |—— TaskControl.ipynb
| |—— Triplets
| |—— TripletCUB.py
| |—— TripletFlowers.py
| |—— TripletFSCIFAR100.py
| |—— TripletMiniImageNet.py
| |—— TripletOmniglot.py
| |—— __init__.py
- software
  OS: Ubuntu 20.04 LTS
  Python: 3.9.7 (anaconda)
- hardware
  CPU: Intel(R) Core(TM) i9-10850K CPU @ 3.60 GHz
  GPU: Nvidia RTX3090 (24GB)
