-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathml.py
More file actions
46 lines (36 loc) · 1.4 KB
/
Copy pathml.py
File metadata and controls
46 lines (36 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from joblib import dump, load
from sklearn.datasets import fetch_openml
import numpy as np
# from sklearn.svm import SVC
# from sklearn.linear_model import SGDClassifier
# mnist = fetch_openml("mnist_784", as_frame=False)
# X,y = mnist.data, mnist.target
# class Model:
# """Model Class for predictions"""
# def __init__(self):
# svm_clf = SVC(random_state=42)
# sgd_clf = SGDClassifier(random_state=42)
# # svm_clf.fit(X[:3000], y[:3000])
# sgd_clf.fit(X[:59000], y[:59000])
# # self.model = joblib.load("my_classifier_model.pkl")
# # self.model = svm_clf
# self.model = sgd_clf
# def predict_input(self, arr):
# output = self.model.predict([arr.reshape(784)])
# # print(output)
# return output
# def save_model(self):
# dump(self.model, 'sgd_clf.joblib')
# my_model = Model()
# any_num = X[0]
# my_model.predict_input(any_num)
# print("Shape:", any_num.shape)
# def predict(arr, model=None):
# """Predict function for passing array and model and it returns output ie prediction."""
# output = model.predict([arr.reshape(784)])
# return output
def predict(arr, model=None, image_shape=(28, 28)):
    """Classify a single image with *model* and return the index of the top score.

    The image is reshaped to ``(1, *image_shape)`` (a batch of one), passed to
    ``model.predict`` wrapped in a one-element list, and the arg-max over the
    returned scores is taken as the predicted class.

    Args:
        arr: The image pixels. Any array-like (NumPy array, nested list, ...)
            whose total number of elements equals the product of
            ``image_shape``; e.g. a flat 784-vector or a 28x28 grid for the
            default shape.
        model: An object exposing a ``predict`` method. NOTE(review): the list
            wrapping and arg-max suggest a Keras-style classifier returning
            per-class scores -- confirm against callers.
        image_shape: Shape each image is reshaped to, without the batch axis.
            Defaults to ``(28, 28)``, the value that was previously hard-coded.

    Returns:
        The arg-max index into the flattened model output (a NumPy integer).

    Raises:
        ValueError: If *model* is ``None`` (the default), or if *arr* cannot
            be reshaped to ``(1, *image_shape)``.
    """
    if model is None:
        # Fail with a clear message instead of AttributeError on None.predict.
        raise ValueError("predict() requires a model with a .predict() method")
    # np.asarray lets callers pass plain lists (e.g. decoded JSON pixel data);
    # it returns the same object when arr is already an ndarray.
    batch = np.asarray(arr).reshape((1, *image_shape))
    output = model.predict([batch])
    # Flattened arg-max; assuming output is (1, n_classes), this is the index
    # of the highest-scoring class for the single image in the batch.
    return np.argmax(output)