-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.cpp
More file actions
160 lines (134 loc) · 4.53 KB
/
Copy pathmain.cpp
File metadata and controls
160 lines (134 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/script.h> // One-stop header.
#include <opencv2/opencv.hpp>
#include <opencv2/highgui/highgui.hpp>

#include "cxxopts.hpp"
/// Reports whether `name` can be opened for reading.
///
/// Writes a trace line to stdout before probing the path.
///
/// @param name  path to probe
/// @return true when an input stream opened on `name` is in a good state
inline bool file_exists(const std::string &name) {
    // Trace which path is being probed.
    std::cout << "checking file exists " << name << std::endl;

    // Opening the stream is the probe itself; it is closed again on scope exit.
    std::ifstream probe;
    probe.open(name.c_str());
    const bool opened = probe.good();
    return opened;
}
/// Parses and validates the command line.
///
/// Recognised options: --help, -w/--weights, -i/--input, -o/--output.
/// The three path options are all required.
///
/// Process-exit behaviour:
///   - status 0  : --help was given, or no arguments were parsed (usage is printed);
///   - status -1 : a required option is absent, or the weights/input file cannot be opened;
///   - status 1  : cxxopts rejected the command line.
///
/// @param argc  argument count as received by main
/// @param argv  argument vector as received by main
/// @return the parse result; "weights", "input" and "output" are present when this returns
cxxopts::ParseResult parse(int argc, char *argv[]) {
    try {
        cxxopts::Options options(argv[0], " - SR command line options");
        options.add_options()
            ("help", "Print help")
            ("w, weights", "weights file path", cxxopts::value<std::string>())
            ("i, input", "input image file path", cxxopts::value<std::string>())
            ("o, output", "output image file path", cxxopts::value<std::string>());

        auto result = options.parse(argc, argv);
        if (result.count("help") || result.arguments().empty()) {
            std::cout << options.help() << std::endl;
            exit(0);
        }

        // Report every absent option before bailing out, so the user can fix them all at once.
        bool missing = false;
        for (const char *o : { "w", "i", "o" }) {
            if (result.count(o) == 0) {
                std::cerr << "missing arg " << o << std::endl;
                missing = true;
            }
        }
        if (missing) exit(-1);

        // Only the two input paths must already exist; the output file is created later.
        for (const char *o : { "w", "i" }) {
            const std::string path = result[o].as<std::string>();
            if (!file_exists(path)) {
                std::cerr << "missing file " << o << ": " << path << std::endl;
                missing = true;
            }
        }
        if (missing) exit(-1);

        std::cout << "weights = " << result["weights"].as<std::string>()
                  << std::endl;
        std::cout << "input = " << result["input"].as<std::string>()
                  << std::endl;
        std::cout << "output = " << result["output"].as<std::string>()
                  << std::endl;
        return result;
    }
    catch (const cxxopts::OptionException &e) {
        // Diagnostics belong on stderr so they are not mixed into regular output.
        std::cerr << "error parsing options: " << e.what() << std::endl;
        exit(1);
    }
}
void check_cuda() {
int count = 0;
if (cudaGetDeviceCount(&count) == cudaError::cudaSuccess) {
std::printf("%d.%d", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);
if (count == 0) {
std::cerr << "couldn't get number of gpus";
exit(-1);
}
}
else {
std::cerr << "couldn't get cuda device count";
exit(-1);
}
}
/// Converts an OpenCV image into a float tensor shaped (1, channels, height, width).
///
/// The pixel data is converted to 32-bit float and scaled by 1/255, wrapped as a
/// (1, height, width, channels) tensor, permuted to channels-first, and cloned so
/// the returned tensor owns its storage instead of pointing into the Mat buffer.
///
/// @param frame  source image (taken by value; the conversion happens on this local header)
/// @return cloned float tensor in (1, C, H, W) order
at::Tensor cv2_to_torch(cv::Mat frame) {
    // NOTE(review): presumably only the depth of CV_32FC1 matters to convertTo and the
    // channel count of `frame` is kept - confirm against the OpenCV docs.
    frame.convertTo(frame, CV_32FC1, 1.0f / 255.0f);

    const int height = frame.rows;
    const int width = frame.cols;
    const int channels = frame.channels();

    // from_blob does not copy: this tensor merely views the Mat's memory.
    at::Tensor nhwc = torch::from_blob(frame.ptr<float>(), { 1, height, width, channels });

    // clone() detaches the result from `frame`, which is destroyed on return.
    return nhwc.permute({ 0, 3, 1, 2 }).clone();
}
/// Loads the image at `fp` as a single-channel grayscale cv::Mat.
///
/// cv::imread reports failure by returning an empty Mat rather than throwing, so
/// that case is checked here: an error is printed and the process exits with
/// status -1, in line with the other fatal-error paths in this file.
///
/// @param fp  path of the image file to read
/// @return the decoded, non-empty grayscale image
cv::Mat cv2_image(const std::string &fp) {
    cv::Mat image = cv::imread(fp, cv::IMREAD_GRAYSCALE);
    if (image.empty()) {
        // The file may exist but be unreadable or not a supported image format.
        std::cerr << "couldn't read image " << fp << std::endl;
        exit(-1);
    }
    return image;
}
/// Shows `image` in an auto-sized window and blocks until a key is pressed.
///
/// @param image  image to display
void display_cv_image(cv::Mat image) {
    // One shared title keeps the window creation and the imshow target in sync.
    const char *const window_title = "Display window";

    cv::namedWindow(window_title, cv::WINDOW_AUTOSIZE);
    cv::imshow(window_title, image);

    // A delay of 0 waits indefinitely for a key event.
    cv::waitKey(0);
}
/// Converts a (channels, height, width) tensor into an 8-bit OpenCV image.
///
/// Values are multiplied by 255, clamped to [0, 255] and converted to unsigned
/// 8-bit before being copied into a freshly allocated Mat with the same number
/// of channels as the tensor.
///
/// @param tensor  3-D tensor in channels-first order; presumably holding values in [0, 1]
/// @return an height x width Mat of type CV_8UC(channels) that owns its pixel data
cv::Mat torch_to_cv2(at::Tensor tensor) {
    // Channels-first -> channels-last, which is the interleaved layout cv::Mat uses.
    tensor = tensor.detach().permute({ 1, 2, 0 });
    // Quantise before the device transfer so only one byte per element is moved.
    tensor = tensor.mul(255).clamp(0, 255).to(torch::kU8);
    // permute() only rearranges strides; contiguous() guarantees the bytes behind
    // data_ptr() really are in H x W x C order before the raw copy below.
    tensor = tensor.to(torch::kCPU).contiguous();

    const int height = static_cast<int>(tensor.size(0));
    const int width = static_cast<int>(tensor.size(1));
    const int channels = static_cast<int>(tensor.size(2));

    cv::Mat result_img(height, width, CV_8UC(channels));

    // Byte count comes from the tensor itself instead of sizeof() on a dtype enum value.
    const std::size_t byte_count =
        static_cast<std::size_t>(tensor.numel()) * tensor.element_size();
    if (byte_count > 0) {
        std::memcpy(result_img.data, tensor.data_ptr(), byte_count);
    }
    return result_img;
}
int main(int argc, char *argv[]) {
torch::manual_seed(1);
check_cuda();
auto result = parse(argc, argv);
const auto &arguments = result.arguments();
torch::jit::script::Module model;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
model = torch::jit::load(result["weights"].as<std::string>());
model.to(at::kCUDA);
}
catch (const c10::Error &e) {
std::cerr << "error loading the model\n";
std::cerr << result["weights"].as<std::string>();
return -1;
}
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
cv::Mat img = cv2_image(result["input"].as<std::string>());
std::cout << img.size() << std::endl;
//image.reshape( 1, image.cols * image.rows );
at::Tensor t_img = cv2_to_torch(img); // cv to torch code 구현
t_img = t_img.to(at::kCUDA); // input에 cuda를 붙여줘야 한다. 불일치시 에러
inputs.emplace_back(t_img);
at::Tensor output = model.forward(inputs).toTensor();
cv::Mat sr_img = torch_to_cv2(output[0]);
display_cv_image(sr_img);
cv::imwrite(result["output"].as<std::string>(), sr_img);
}
//for (const auto & pair : model.named_parameters()) {
// //pair.value
// std::cout<< pair.name<<", " <<pair.value.sizes()<< std::endl;
//}
// Execute the model and turn its output into a tensor.
//auto data = batch.data.to(device), targets = batch.target.to(device);
//optimizer.zero_grad();
// auto output = model.forward(data);