-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprint_test_result.py
More file actions
executable file
·145 lines (144 loc) · 5.87 KB
/
print_test_result.py
File metadata and controls
executable file
·145 lines (144 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/python3
'''
Abstract:
This program prints the basic results of AI testing, for each trained model and for their ensemble.
Usage:
print_test_result.py [main name of the test set]
Example:
print_test_result.py MaxLoss15
Editor:
Jacob975
##################################
# Python3 #
# This code is made in python3 #
##################################
20180430
####################################
update log
20180430 version alpha 1:
1. The code works.
20180601 version alpha 2:
1. Add a function to print the reliable data.
20190529 version alpha 3:
1. Use the main name of the test set instead of a keyword.
'''
import numpy as np
import time
from load_lib import confusion_matrix_infos, load_arrangement, load_labels_pred, load_cls_true
from sys import argv
from glob import glob
import os
#--------------------------------------------
# main code
def _count_label(cls_array, label):
    """Return how many entries of ``cls_array`` equal ``label``.

    Counting is best-effort: if ``cls_array`` cannot be boolean-indexed
    (e.g. it is ``None``), 0 is reported instead of aborting the report.
    """
    try:
        return len(cls_array[cls_array == label])
    except Exception:  # display-only; never let a count kill the report
        return 0


# (display name, class label) pairs used for the per-class source counts.
_CLASS_LABELS = (("stars", 0), ("galaxies", 1), ("ysos", 2))


def _print_source_counts(header, cls_array):
    """Print ``header`` followed by the number of sources in each class."""
    print(header)
    for class_name, label in _CLASS_LABELS:
        print("number of {0}: {1}".format(class_name, _count_label(cls_array, label)))


def _print_report(infos):
    """Print source counts, confusion matrices and metrics held by ``infos``.

    ``infos`` is a ``confusion_matrix_infos`` instance. The reliable-source
    section is printed only when ``infos.reliable`` is truthy.
    """
    _print_source_counts("### sources in dataset ### ", infos.cls_true)
    # The failure flag returned by confusion_matrix() is not checked here.
    _, cm, cm_reliable = infos.confusion_matrix()
    print("confusion matrix:\n{0}".format(cm))
    if infos.reliable:
        _print_source_counts("### reliable sources in dataset ### ", infos.cls_true_reliable)
        print("reliable confusion matrix:\n{0}".format(cm_reliable))
    infos.print_accuracy()
    infos.print_precision()
    infos.print_recall_rate()


def _sort_by_index(indices, values):
    """Return ``values`` as a list reordered by ascending ``indices``.

    Used to bring every model's outputs into one common order given by
    ``tracer.test`` (presumably test-set source indices; TODO confirm) so
    they can be averaged element-wise.
    """
    # Sort on the index alone: on a tie, plain tuple comparison would fall
    # through to the paired values, which raises for numpy rows.
    ordered = sorted(zip(indices, values), key=lambda pair: pair[0])
    return [value for _, value in ordered]


def _load_directory(main_name, directory):
    """Load one model's test outputs from ``directory``.

    Returns ``(tracer, labels_pred, cls_true)``, or ``None`` when any of the
    loaders reports failure (a truthy first return value).
    """
    failure, _data, tracer = load_arrangement(main_name, directory)
    if failure:
        print("failed to load data and tracer from {0}, skipped".format(directory))
        return None
    print("load data and tracer success")
    failure, labels_pred = load_labels_pred(main_name, directory)
    if failure:
        print("failed to load labels_pred from {0}, skipped".format(directory))
        return None
    print("load labels_pred success")
    failure, cls_true = load_cls_true(main_name, directory)
    if failure:
        print("failed to load cls_true from {0}, skipped".format(directory))
        return None
    print("load cls_true success")
    return tracer, labels_pred, cls_true


def main():
    """Report per-model and ensemble test results for the test set ``argv[1]``.

    Scans the working directory for ``AI*test_on*<main name>`` entries,
    prints a report for each one that loads, averages the predicted labels
    of all loaded models into an ensemble, prints its report and saves the
    ensemble predictions to ``<main name>_label_pred.txt`` and
    ``<main name>_cls_pred.txt``.

    Exits with status 1 on a usage error or when no directory could be loaded.
    """
    # measure times
    start_time = time.time()
    #----------------------------------------
    # Load argv
    if len(argv) != 2:
        print("Error!\nUsage: print_test_result.py [main name of the test set]")
        # Non-zero status so shells and calling scripts can detect the misuse.
        raise SystemExit(1)
    main_name = argv[1]
    #----------------------------------------
    # Load data
    print(os.getcwd())
    # Sorted so that report order does not depend on filesystem listing order.
    data_list = sorted(glob("AI*test_on*{0}".format(main_name)))
    ensemble_cls_true = None
    labels_pred_set = []
    for directory in data_list:
        print("#################################")
        print("start to loading data saved in {0}".format(directory))
        loaded = _load_directory(main_name, directory)
        if loaded is None:
            continue
        tracer, labels_pred, cls_true = loaded
        labels_pred_set.append(_sort_by_index(tracer.test, labels_pred))
        # True classes are taken from the first usable directory only
        # (presumably identical across directories; TODO confirm).
        if ensemble_cls_true is None:
            ensemble_cls_true = _sort_by_index(tracer.test, cls_true)
        #-----------------------------------
        # print the properties of sources
        _print_report(confusion_matrix_infos(cls_true, labels_pred))
    if not labels_pred_set:
        # Without at least one model there is nothing to ensemble.
        print("Error!\nNo usable test result matches AI*test_on*{0}".format(main_name))
        raise SystemExit(1)
    #-----------------------------------
    # print the result of ensemble results
    ensemble_labels_pred = np.mean(np.array(labels_pred_set), axis=0)
    infos = confusion_matrix_infos(np.array(ensemble_cls_true), ensemble_labels_pred)
    print("\n#################################")
    print("### prediction of ensemble AI ###")
    print("#################################")
    _print_report(infos)
    cls_pred = np.argmax(infos.labels_pred, axis=1)
    np.savetxt("{0}_label_pred.txt".format(main_name), infos.labels_pred)
    # argmax yields integer class indices; write them as integers rather than
    # numpy's default "%.18e" float format.
    np.savetxt("{0}_cls_pred.txt".format(main_name), cls_pred, fmt="%d")
    #----------------------------------------
    # measuring time
    elapsed_time = time.time() - start_time
    print("Exiting Main Program, spending ", elapsed_time, "seconds.")


if __name__ == "__main__":
    main()