Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions graph_framework.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,7 @@
"-lLLVMSandboxIR",
"-lLLVMObjectYAML",
"-lLLVMPlugins",
"-lLLVMABI",
"-lLLVMFrontendAtomic",
"-lclangFrontend",
"-lclangBasic",
Expand Down Expand Up @@ -2301,6 +2302,7 @@
"-lLLVMSandboxIR",
"-lLLVMObjectYAML",
"-lLLVMPlugins",
"-lLLVMABI",
"-lLLVMFrontendAtomic",
"-lclangFrontend",
"-lclangBasic",
Expand Down
126 changes: 89 additions & 37 deletions graph_framework/piecewise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,40 @@ void compile_index(std::ostringstream &stream,
const size_t length,
const T scale,
const T offset) {
const std::string type = jit::smallest_int_type<T> (length);
stream << "min(max(("
<< type
<< ")";
const std::string type = jit::type_to_string<T> ();
stream << "(" << jit::smallest_uint_type<T> (length) << ")min";
if constexpr (!jit::use_metal<T> ()) {
stream << "<" << type << ">";
}
stream << "(max";
if constexpr (!jit::use_metal<T> ()) {
stream << "<" << type << ">";
}
stream << "(";
if constexpr (jit::complex_scalar<T>) {
stream << "real(";
}
stream << "((" << register_name << " - ";
stream << "(" << register_name << " - ";
if constexpr (jit::complex_scalar<T>) {
stream << jit::get_type_string<T> ();
}
stream << offset << ")/";
if constexpr (jit::complex_scalar<T>) {
stream << jit::get_type_string<T> ();
}
stream << scale << ")";
stream << scale;
if constexpr (jit::complex_scalar<T>) {
stream << ")";
}
stream << ",(" << type << ")0),("
<< type << ")" << length - 1 << ")";
stream << ",";
if constexpr (jit::use_metal<T> ()) {
stream << "(" << type << ")";
}
stream << "0),";
if constexpr (jit::use_metal<T> ()) {
stream << "(" << type << ")";
}
stream << length - 1 << ")";
}

//******************************************************************************
Expand Down Expand Up @@ -192,9 +205,17 @@ void compile_index(std::ostringstream &stream,
virtual shared_leaf<T, SAFE_MATH> reduce() {
if (constant_cast(this->arg).get()) {
const T arg = (this->arg->evaluate().at(0) + offset)/scale;
const size_t i = std::min(static_cast<size_t> (std::real(arg)),
this->get_size() - 1);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
if constexpr (jit::float_base<T>) {
const size_t i = std::max<float> (std::min<float> (std::real(arg),
this->get_size() - 1),
0);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
} else {
const size_t i = std::max<double> (std::min<double> (std::real(arg),
this->get_size() - 1),
0);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
}
}

if (evaluate().is_same()) {
Expand Down Expand Up @@ -333,7 +354,7 @@ void compile_index(std::ostringstream &stream,
#ifdef USE_INDEX_CACHE
indices[a.get()] = jit::to_string('i', a.get());
stream << " const "
<< jit::smallest_int_type<T> (length) << " "
<< jit::smallest_uint_type<T> (length) << " "
<< indices[a.get()] << " = ";
compile_index<T> (stream, registers[a.get()], length,
scale, offset);
Expand All @@ -347,7 +368,7 @@ void compile_index(std::ostringstream &stream,
stream << " " << registers[this] << " = ";
#ifdef USE_CUDA_TEXTURES
if constexpr (jit::use_cuda()) {
if constexpr (float_base<T>) {
if constexpr (jit::float_base<T>) {
if constexpr (complex_scalar<T>) {
stream << "to_cmp_float(tex1D<float2> (";
} else {
Expand Down Expand Up @@ -835,25 +856,56 @@ void compile_index(std::ostringstream &stream,
constant_cast(this->right).get()) {
const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
const size_t i = std::min(static_cast<size_t> (std::real(l)),
this->get_num_rows() - 1);
const size_t j = std::min(static_cast<size_t> (std::real(r)),
this->get_num_columns() - 1);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);

if constexpr (jit::float_base<T>) {
const size_t i = std::max<float> (std::min<float> (std::real(l),
this->get_num_rows() - 1),
0);
const size_t j = std::max<float> (std::min<float> (std::real(r),
this->get_num_columns() - 1),
0);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
} else {
const size_t i = std::max<double> (std::min<double> (std::real(l),
this->get_num_rows() - 1),
0);
const size_t j = std::max<double> (std::min<double> (std::real(r),
this->get_num_columns() - 1),
0);
return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
}
} else if (constant_cast(this->left).get()) {
const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
const size_t i = std::min(static_cast<size_t> (std::real(l)),
this->get_num_rows() - 1);

return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
this->right, y_scale, y_offset);
if constexpr (jit::float_base<T>) {
const size_t i = std::max<float> (std::min<float> (std::real(l),
this->get_num_rows() - 1),
0);
return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
this->right, y_scale, y_offset);
} else {
const size_t i = std::max<double> (std::min<double> (std::real(l),
this->get_num_rows() - 1),
0);
return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
this->right, y_scale, y_offset);
}
} else if (constant_cast(this->right).get()) {
const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
const size_t j = std::min(static_cast<size_t> (std::real(r)),
this->get_num_columns() - 1);

return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
this->left, x_scale, x_offset);

if constexpr (jit::float_base<T>) {
const size_t j = std::max<float> (std::min<float> (std::real(r),
this->get_num_columns() - 1),
0);
return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
this->left, x_scale, x_offset);
} else {
const size_t j = std::max<double> (std::min<double> (std::real(r),
this->get_num_columns() - 1),
0);
return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
this->left, x_scale, x_offset);
}
}

if (evaluate().is_same()) {
Expand Down Expand Up @@ -1015,7 +1067,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(x.get()) == indices.end()) {
indices[x.get()] = jit::to_string('i', x.get());
stream << " const "
<< jit::smallest_int_type<T> (num_rows) << " "
<< jit::smallest_uint_type<T> (num_rows) << " "
<< indices[x.get()] << " = ";
compile_index<T> (stream, registers[x.get()], num_rows,
x_scale, x_offset);
Expand All @@ -1024,7 +1076,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(y.get()) == indices.end()) {
indices[y.get()] = jit::to_string('i', y.get());
stream << " const "
<< jit::smallest_int_type<T> (num_columns) << " "
<< jit::smallest_uint_type<T> (num_columns) << " "
<< indices[y.get()] << " = ";
compile_index<T> (stream, registers[y.get()], num_columns,
y_scale, y_offset);
Expand All @@ -1040,7 +1092,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(temp.get()) == indices.end()) {
indices[temp.get()] = jit::to_string('i', temp.get());
stream << " const "
<< jit::smallest_int_type<T> (length) << " "
<< jit::smallest_uint_type<T> (length) << " "
<< indices[temp.get()] << " = "
<< indices[x.get()]
<< "*" << num_columns << " + "
Expand All @@ -1056,7 +1108,7 @@ void compile_index(std::ostringstream &stream,
stream << " " << registers[this] << " = ";
#ifdef USE_CUDA_TEXTURES
if constexpr (jit::use_cuda()) {
if constexpr (float_base<T>) {
if constexpr (jit::float_base<T>) {
if constexpr (complex_scalar<T>) {
stream << "to_cmp_float(tex1D<float2> (";
} else {
Expand All @@ -1075,8 +1127,8 @@ void compile_index(std::ostringstream &stream,
if constexpr (jit::use_metal<T> ()) {
#ifdef USE_INDEX_CACHE
stream << ".read("
<< jit::smallest_int_type<T> (std::max(num_rows,
num_columns))
<< jit::smallest_uint_type<T> (std::max(num_rows,
num_columns))
<< "2("
<< indices[y.get()]
<< ","
Expand Down Expand Up @@ -1461,7 +1513,7 @@ void compile_index(std::ostringstream &stream,
#ifdef USE_INDEX_CACHE
indices[a.get()] = jit::to_string('i', a.get());
stream << " const "
<< jit::smallest_int_type<T> (length) << " "
<< jit::smallest_uint_type<T> (length) << " "
<< indices[a.get()] << " = ";
compile_index<T> (stream, registers[a.get()], length,
scale, offset);
Expand Down Expand Up @@ -1797,7 +1849,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(x.get()) == indices.end()) {
indices[x.get()] = jit::to_string('i', x.get());
stream << " const "
<< jit::smallest_int_type<T> (num_rows) << " "
<< jit::smallest_uint_type<T> (num_rows) << " "
<< indices[x.get()] << " = ";
compile_index<T> (stream, registers[x.get()], num_rows,
x_scale, x_offset);
Expand All @@ -1806,7 +1858,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(y.get()) == indices.end()) {
indices[y.get()] = jit::to_string('i', y.get());
stream << " const "
<< jit::smallest_int_type<T> (num_columns) << " "
<< jit::smallest_uint_type<T> (num_columns) << " "
<< indices[y.get()] << " = ";
compile_index<T> (stream, registers[y.get()], num_columns,
y_scale, y_offset);
Expand All @@ -1819,7 +1871,7 @@ void compile_index(std::ostringstream &stream,
if (indices.find(temp.get()) == indices.end()) {
indices[temp.get()] = jit::to_string('i', temp.get());
stream << " const "
<< jit::smallest_int_type<T> (length) << " "
<< jit::smallest_uint_type<T> (length) << " "
<< indices[temp.get()] << " = "
<< indices[x.get()]
<< "*" << num_columns << " + "
Expand Down
2 changes: 1 addition & 1 deletion graph_framework/register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ namespace jit {
/// @returns The smallest integer type as a string.
//------------------------------------------------------------------------------
template<float_scalar T>
std::string smallest_int_type(const size_t max_size) {
std::string smallest_uint_type(const size_t max_size) {
if (max_size <= std::numeric_limits<unsigned char>::max()) {
if constexpr (jit::use_metal<T> ()) {
return "ushort";
Expand Down
Loading