forked from cosmic-cortex/pytorch-UNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
23 lines (16 loc) · 693 Bytes
/
predict.py
File metadata and controls
23 lines (16 loc) · 693 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os
from argparse import ArgumentParser

import torch

from unet.dataset import Image2D
from unet.model import Model
# Command-line interface:
#   --dataset       folder passed to Image2D to build the prediction dataset
#   --results_path  output folder; also used as Model's checkpoint_folder
#   --model_path    file holding the network object saved with torch.save
#   --device        device to load and run the network on (default: 'cpu')
parser = ArgumentParser()
parser.add_argument('--dataset', required=True, type=str)
parser.add_argument('--results_path', required=True, type=str)
parser.add_argument('--model_path', required=True, type=str)
parser.add_argument('--device', default='cpu', type=str)
args = parser.parse_args()

predict_dataset = Image2D(args.dataset)

# Load the serialized network. map_location remaps the stored tensors onto
# the requested device, so a checkpoint saved on a GPU can still be loaded
# with the default --device cpu.
# NOTE: torch.load unpickles arbitrary Python objects; only load trusted
# checkpoint files.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
# which rejects whole pickled nn.Module objects. If this checkpoint is a
# full model, weights_only=False may be required on newer versions.
unet = torch.load(args.model_path, map_location=args.device)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(args.results_path, exist_ok=True)

model = Model(unet, checkpoint_folder=args.results_path, device=args.device)
# Predictions are written to the same folder given by --results_path.
model.predict_dataset(predict_dataset, args.results_path)