-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference.py
More file actions
51 lines (40 loc) · 2.31 KB
/
inference.py
File metadata and controls
51 lines (40 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import sys
from utils import getPretrainedModel, get_neighbors, get_images_path, \
getTrainingEmbeddingsFromFile, getEmbeddingsforImagePathList, plot_canvas
from config import CFG
from arcface import get_test_transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def _plot_distance_curve(distances, k, max_points=8):
    """Plot the cosine distances from query image *k* to its nearest train rows.

    Args:
        distances: 1-D sequence of cosine distances for query ``k``, as
            returned by ``get_neighbors`` (presumably sorted ascending by
            distance -- TODO confirm against ``utils.get_neighbors``).
        k: Position of the query image; used only in the plot labels.
        max_points: Maximum number of leading neighbours to plot.
    """
    # Clip to what is actually available so x and y always have equal length,
    # even when fewer than ``max_points`` neighbours were returned.
    shown = np.asarray(distances)[:max_points]
    plt.figure(figsize=(20, 3))
    plt.plot(np.arange(len(shown)), shown, 'o-')
    plt.title('Cosine Distance From Query Image {} to Its Nearest Train Rows'.format(k), size=16)
    plt.ylabel('Cosine Distance to Query Image {}'.format(k), size=14)
    plt.xlabel('Neighbor Rank (by Cosine Distance to Query Image {})'.format(k), size=14)


def main(num_queries=6, knn=50, num_plot_neighbors=8, num_canvas_neighbors=5):
    """Retrieve the nearest training images for each test image and plot them.

    Loads the ArcFace model and the pre-computed training embeddings, embeds
    every image under ``CFG.TEST_DIR`` and finds its ``knn`` nearest training
    rows by cosine distance. For the first ``num_queries`` query images it then
    draws a distance curve and a canvas of the closest training images.

    Args:
        num_queries: How many query images to visualise. This is capped at
            the number of query images actually found.
        knn: Number of neighbours requested from ``get_neighbors``.
        num_plot_neighbors: Number of leading distances drawn in each curve.
        num_canvas_neighbors: Number of nearest training rows shown (as
            columns) in each ``plot_canvas`` call.
    """
    # Training Mode off
    CFG.isTraining = False
    train_df = pd.read_csv(CFG.DATA_DIR)

    # Load Trained Model
    shopee_model = getPretrainedModel(loss_module='ArcFace', model_path=CFG.model_path_arcface, device=CFG.device)

    # Training embeddings were already stored on disk, so reuse them.
    training_embeddings = getTrainingEmbeddingsFromFile(CFG.embedding_path)
    # To regenerate them instead, replace the line above with:
    #   train_image_paths = get_images_path(CFG.TRAIN_DIR)
    #   training_embeddings = getEmbeddingsforImagePathList(
    #       queryImagesPath=train_image_paths, model=shopee_model,
    #       transform=get_test_transforms())

    # Embed the query (test) images.
    query_images_path = get_images_path(CFG.TEST_DIR)
    query_embeddings = getEmbeddingsforImagePathList(queryImagesPath=query_images_path, model=shopee_model,
                                                     transform=get_test_transforms())

    # Cosine distances and train-row positions of the ``knn`` nearest neighbours.
    query_cosine_distances, query_cosine_indices = get_neighbors(train_embeddings=training_embeddings,
                                                                 query_embeddings=query_embeddings,
                                                                 KNN=knn, metric_param='cosine')

    # Never index past the queries actually returned (the old fixed range(6)
    # raised IndexError with fewer than six test images).
    n_queries = min(num_queries, len(query_cosine_distances))
    for k in range(n_queries):
        _plot_distance_curve(query_cosine_distances[k], k, max_points=num_plot_neighbors)
        # Neighbour indices are positions into the training embeddings, so use
        # positional .iloc (identical to .loc on read_csv's default RangeIndex).
        # NOTE(review): assumes embedding order matches train_df row order.
        cluster = train_df.iloc[query_cosine_indices[k, :num_canvas_neighbors]]
        plot_canvas(cluster, COLS=num_canvas_neighbors, ROWS=1, path=CFG.TRAIN_DIR + "/",
                    img_list=query_images_path, k=k)
if __name__ == '__main__':
    # Script entry point: run the retrieval/visualisation demo only when this
    # file is executed directly, not when it is imported as a module.
    main()