diff --git a/graph_framework.xcodeproj/project.pbxproj b/graph_framework.xcodeproj/project.pbxproj index 8afd529..e6213ad 100644 --- a/graph_framework.xcodeproj/project.pbxproj +++ b/graph_framework.xcodeproj/project.pbxproj @@ -2126,6 +2126,7 @@ "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMPlugins", + "-lLLVMABI", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", @@ -2301,6 +2302,7 @@ "-lLLVMSandboxIR", "-lLLVMObjectYAML", "-lLLVMPlugins", + "-lLLVMABI", "-lLLVMFrontendAtomic", "-lclangFrontend", "-lclangBasic", diff --git a/graph_framework/piecewise.hpp b/graph_framework/piecewise.hpp index b385853..56de5c3 100644 --- a/graph_framework/piecewise.hpp +++ b/graph_framework/piecewise.hpp @@ -28,14 +28,20 @@ void compile_index(std::ostringstream &stream, const size_t length, const T scale, const T offset) { - const std::string type = jit::smallest_int_type (length); - stream << "min(max((" - << type - << ")"; + const std::string type = jit::type_to_string (); + stream << "(" << jit::smallest_uint_type (length) << ")min"; + if constexpr (!jit::use_metal ()) { + stream << "<" << type << ">"; + } + stream << "(max"; + if constexpr (!jit::use_metal ()) { + stream << "<" << type << ">"; + } + stream << "("; if constexpr (jit::complex_scalar) { stream << "real("; } - stream << "((" << register_name << " - "; + stream << "(" << register_name << " - "; if constexpr (jit::complex_scalar) { stream << jit::get_type_string (); } @@ -43,12 +49,19 @@ void compile_index(std::ostringstream &stream, if constexpr (jit::complex_scalar) { stream << jit::get_type_string (); } - stream << scale << ")"; + stream << scale; if constexpr (jit::complex_scalar) { stream << ")"; } - stream << ",(" << type << ")0),(" - << type << ")" << length - 1 << ")"; + stream << ","; + if constexpr (jit::use_metal ()) { + stream << "(" << type << ")"; + } + stream << "0),"; + if constexpr (jit::use_metal ()) { + stream << "(" << type << ")"; + } + stream << length - 1 << ")"; } //****************************************************************************** @@ -192,9 +205,17 @@ void compile_index(std::ostringstream &stream, virtual shared_leaf reduce() { if (constant_cast(this->arg).get()) { const T arg = (this->arg->evaluate().at(0) + offset)/scale; - const size_t i = std::min(static_cast (std::real(arg)), - this->get_size() - 1); - return constant (leaf_node::caches.backends[data_hash][i]); + if constexpr (jit::float_base) { + const size_t i = std::max (std::min (std::real(arg), + this->get_size() - 1), + 0); + return constant (leaf_node::caches.backends[data_hash][i]); + } else { + const size_t i = std::max (std::min (std::real(arg), + this->get_size() - 1), + 0); + return constant (leaf_node::caches.backends[data_hash][i]); + } } if (evaluate().is_same()) { @@ -333,7 +354,7 @@ void compile_index(std::ostringstream &stream, #ifdef USE_INDEX_CACHE indices[a.get()] = jit::to_string('i', a.get()); stream << " const " - << jit::smallest_int_type (length) << " " + << jit::smallest_uint_type (length) << " " << indices[a.get()] << " = "; compile_index (stream, registers[a.get()], length, scale, offset); @@ -347,7 +368,7 @@ void compile_index(std::ostringstream &stream, stream << " " << registers[this] << " = "; #ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { - if constexpr (float_base) { + if constexpr (jit::float_base) { if constexpr (complex_scalar) { stream << "to_cmp_float(tex1D ("; } else { @@ -835,25 +856,56 @@ void compile_index(std::ostringstream &stream, constant_cast(this->right).get()) { const T l = (this->left->evaluate().at(0) + x_offset)/x_scale; const T r = (this->right->evaluate().at(0) + y_offset)/y_scale; - const size_t i = std::min(static_cast (std::real(l)), - this->get_num_rows() - 1); - const size_t j = std::min(static_cast (std::real(r)), - this->get_num_columns() - 1); - return constant (leaf_node::caches.backends[data_hash][i*this->get_num_columns() + j]); + + if constexpr (jit::float_base) { + const size_t i = std::max (std::min (std::real(l), + this->get_num_rows() - 1), + 0); + const size_t j = std::max (std::min (std::real(r), + this->get_num_columns() - 1), + 0); + return constant (leaf_node::caches.backends[data_hash][i*this->get_num_columns() + j]); + } else { + const size_t i = std::max (std::min (std::real(l), + this->get_num_rows() - 1), + 0); + const size_t j = std::max (std::min (std::real(r), + this->get_num_columns() - 1), + 0); + return constant (leaf_node::caches.backends[data_hash][i*this->get_num_columns() + j]); + } } else if (constant_cast(this->left).get()) { const T l = (this->left->evaluate().at(0) + x_offset)/x_scale; - const size_t i = std::min(static_cast (std::real(l)), - this->get_num_rows() - 1); - return piecewise_1D(leaf_node::caches.backends[data_hash].index_row(i, this->get_num_columns()), - this->right, y_scale, y_offset); + if constexpr (jit::float_base) { + const size_t i = std::max (std::min (std::real(l), + this->get_num_rows() - 1), + 0); + return piecewise_1D(leaf_node::caches.backends[data_hash].index_row(i, this->get_num_columns()), + this->right, y_scale, y_offset); + } else { + const size_t i = std::max (std::min (std::real(l), + this->get_num_rows() - 1), + 0); + return piecewise_1D(leaf_node::caches.backends[data_hash].index_row(i, this->get_num_columns()), + this->right, y_scale, y_offset); + } } else if (constant_cast(this->right).get()) { const T r = (this->right->evaluate().at(0) + y_offset)/y_scale; - const size_t j = std::min(static_cast (std::real(r)), - this->get_num_columns() - 1); - - return piecewise_1D(leaf_node::caches.backends[data_hash].index_column(j, this->get_num_columns()), - this->left, x_scale, x_offset); + + if constexpr (jit::float_base) { + const size_t j = std::max (std::min (std::real(r), + this->get_num_columns() - 1), + 0); + return piecewise_1D(leaf_node::caches.backends[data_hash].index_column(j, this->get_num_columns()), + this->left, x_scale, x_offset); + } else { + const size_t j = std::max (std::min (std::real(r), + this->get_num_columns() - 1), + 0); + return piecewise_1D(leaf_node::caches.backends[data_hash].index_column(j, this->get_num_columns()), + this->left, x_scale, x_offset); + } } if (evaluate().is_same()) { @@ -1015,7 +1067,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(x.get()) == indices.end()) { indices[x.get()] = jit::to_string('i', x.get()); stream << " const " - << jit::smallest_int_type (num_rows) << " " + << jit::smallest_uint_type (num_rows) << " " << indices[x.get()] << " = "; compile_index (stream, registers[x.get()], num_rows, x_scale, x_offset); @@ -1024,7 +1076,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(y.get()) == indices.end()) { indices[y.get()] = jit::to_string('i', y.get()); stream << " const " - << jit::smallest_int_type (num_columns) << " " + << jit::smallest_uint_type (num_columns) << " " << indices[y.get()] << " = "; compile_index (stream, registers[y.get()], num_columns, y_scale, y_offset); @@ -1040,7 +1092,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(temp.get()) == indices.end()) { indices[temp.get()] = jit::to_string('i', temp.get()); stream << " const " - << jit::smallest_int_type (length) << " " + << jit::smallest_uint_type (length) << " " << indices[temp.get()] << " = " << indices[x.get()] << "*" << num_columns << " + " @@ -1056,7 +1108,7 @@ void compile_index(std::ostringstream &stream, stream << " " << registers[this] << " = "; #ifdef USE_CUDA_TEXTURES if constexpr (jit::use_cuda()) { - if constexpr (float_base) { + if constexpr (jit::float_base) { if constexpr (complex_scalar) { stream << "to_cmp_float(tex1D ("; } else { @@ -1075,8 +1127,8 @@ void compile_index(std::ostringstream &stream, if constexpr (jit::use_metal ()) { #ifdef USE_INDEX_CACHE stream << ".read(" - << jit::smallest_int_type (std::max(num_rows, - num_columns)) + << jit::smallest_uint_type (std::max(num_rows, + num_columns)) << "2(" << indices[y.get()] << "," @@ -1461,7 +1513,7 @@ void compile_index(std::ostringstream &stream, #ifdef USE_INDEX_CACHE indices[a.get()] = jit::to_string('i', a.get()); stream << " const " - << jit::smallest_int_type (length) << " " + << jit::smallest_uint_type (length) << " " << indices[a.get()] << " = "; compile_index (stream, registers[a.get()], length, scale, offset); @@ -1797,7 +1849,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(x.get()) == indices.end()) { indices[x.get()] = jit::to_string('i', x.get()); stream << " const " - << jit::smallest_int_type (num_rows) << " " + << jit::smallest_uint_type (num_rows) << " " << indices[x.get()] << " = "; compile_index (stream, registers[x.get()], num_rows, x_scale, x_offset); @@ -1806,7 +1858,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(y.get()) == indices.end()) { indices[y.get()] = jit::to_string('i', y.get()); stream << " const " - << jit::smallest_int_type (num_columns) << " " + << jit::smallest_uint_type (num_columns) << " " << indices[y.get()] << " = "; compile_index (stream, registers[y.get()], num_columns, y_scale, y_offset); @@ -1819,7 +1871,7 @@ void compile_index(std::ostringstream &stream, if (indices.find(temp.get()) == indices.end()) { indices[temp.get()] = jit::to_string('i', temp.get()); stream << " const " - << jit::smallest_int_type (length) << " " + << jit::smallest_uint_type (length) << " " << indices[temp.get()] << " = " << indices[x.get()] << "*" << num_columns << " + " diff --git a/graph_framework/register.hpp b/graph_framework/register.hpp index d7dcdc4..96fdd2a 100644 --- a/graph_framework/register.hpp +++ b/graph_framework/register.hpp @@ -105,7 +105,7 @@ namespace jit { /// @returns The smallest integer type as a string. //------------------------------------------------------------------------------ template - std::string smallest_int_type(const size_t max_size) { + std::string smallest_uint_type(const size_t max_size) { if (max_size <= std::numeric_limits::max()) { if constexpr (jit::use_metal ()) { return "ushort"; diff --git a/graph_tests/piecewise_test.cpp b/graph_tests/piecewise_test.cpp index 5647cb7..267e88b 100644 --- a/graph_tests/piecewise_test.cpp +++ b/graph_tests/piecewise_test.cpp @@ -182,6 +182,11 @@ template void piecewise_1D() { assert(graph::constant_cast(graph::atan(p1, p3)).get() && "Expected a constant node."); + a->set(static_cast (-1.5)); + compile ({graph::variable_cast(a)}, + {p1}, {}, + static_cast (1.0), 0.0); + a->set(static_cast (1.5)); compile ({graph::variable_cast(a)}, {p1}, {}, @@ -244,6 +249,15 @@ template void piecewise_1D() { assert(graph::constant_cast(pc).get() && "Expected a constant."); + auto neg_one = graph::none (); + auto ps = graph::piecewise_1D (std::vector ({static_cast (2.0), + static_cast (4.0), + static_cast (6.0)}), + neg_one, 1.0, 0.0); + auto ps_cast = graph::constant_cast(ps); + assert(ps_cast.get() && "Expected a constant"); + assert(ps_cast->is(2.0) && "Expected a value of 2"); + // fma(p1,c1 + a,p2) -> fma(p1,a,p3) auto fma_combine = fma(p1,1.0 + a,p3); auto fma_combine_cast = graph::fma_cast(fma_combine); @@ -323,6 +337,16 @@ template void piecewise_2D() { assert(pc_cast.get() && "Expected a constant node."); assert(pc_cast->is(4.0) && "Expected a value of 6"); + auto cx2 = graph::constant (static_cast (-0.5)); + auto cy2 = graph::constant (static_cast (-1.5)); + auto pconst2 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, cx2, 1.0, 0.0, cy2, 1.0, 0.0); + auto pc_cast2 = constant_cast(pconst2); + assert(pc_cast2.get() && "Expected a constant node."); + assert(pc_cast2->is(2.0) && "Expected a value of 2"); + auto p1const = graph::piecewise_2D (std::vector ({ static_cast (2.0), static_cast (4.0), static_cast (6.0), static_cast (10.0) @@ -335,6 +359,18 @@ template void piecewise_2D() { assert(buffer[1] == static_cast (4.0) && "Expected a 4 in the second index."); + auto p1const2 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, cx2, 1.0, 0.0, ay, 1.0, 0.0); + auto p1c_cast2 = piecewise_1D_cast(p1const2); + assert(p1c_cast2.get() && "Expected a piecewise constant."); + buffer = p1c_cast2->evaluate(); + assert(buffer[0] == static_cast (2.0) && + "Expected a 2 in the first index."); + assert(buffer[1] == static_cast (4.0) && + "Expected a 4 in the second index."); + auto p2const = graph::piecewise_2D (std::vector ({ static_cast (2.0), static_cast (4.0), static_cast (6.0), static_cast (10.0) @@ -347,6 +383,18 @@ template void piecewise_2D() { assert(buffer[1] == static_cast (10.0) && "Expected a 10 in the second index."); + auto p2const2 = graph::piecewise_2D (std::vector ({ + static_cast (2.0), static_cast (4.0), + static_cast (6.0), static_cast (10.0) + }), 2, ax, 1.0, 0.0, cy2, 1.0, 0.0); + auto p2c_cast2 = piecewise_1D_cast(p2const2); + assert(p2c_cast2.get() && "Expected a piecewise constant."); + buffer = p2c_cast2->evaluate(); + assert(buffer[0] == static_cast (2.0) && + "Expected a 2 in the first index."); + assert(buffer[1] == static_cast (6.0) && + "Expected a 6 in the second index."); + assert(graph::constant_cast(p1*0.0).get() && "Expected a constant node."); @@ -455,6 +503,27 @@ template void piecewise_2D() { assert(graph::piecewise_2D_cast(graph::atan(p1, p5)).get() && "Expected a piecewise_2d node."); + ax->set(static_cast (-1.5)); + ay->set(static_cast (-1.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1}, {}, + static_cast (1.0), 0.0); + + ax->set(static_cast (-1.5)); + ay->set(static_cast (1.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1}, {}, + static_cast (2.0), 0.0); + + ax->set(static_cast (1.5)); + ay->set(static_cast (-1.5)); + compile ({graph::variable_cast(ax), + graph::variable_cast(ay)}, + {p1}, {}, + static_cast (3.0), 0.0); + ax->set(static_cast (1.5)); ay->set(static_cast (1.5)); compile ({graph::variable_cast(ax), @@ -748,6 +817,13 @@ template void index_1D() { graph::variable_cast(arg)}, {index}, {}, static_cast (3.0), 0.0); + + arg->set(static_cast (-3.5)); + + compile ({graph::variable_cast(variable), + graph::variable_cast(arg)}, + {index}, {}, + static_cast (0.0), 0.0); } //------------------------------------------------------------------------------ @@ -764,10 +840,41 @@ template void index_2D() { x, 1.0, 0.0, y, 1.0, 0.0); - variable->set({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}); + variable->set({ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0 + }); x->set(static_cast (2.5)); y->set(static_cast (0.5)); + compile ({graph::variable_cast(variable), + graph::variable_cast(x), + graph::variable_cast(y)}, + {index}, {}, + static_cast (7.0), 0.0); + + x->set(static_cast (-2.5)); + y->set(static_cast (-0.5)); + + compile ({graph::variable_cast(variable), + graph::variable_cast(x), + graph::variable_cast(y)}, + {index}, {}, + static_cast (1.0), 0.0); + + x->set(static_cast (-2.5)); + y->set(static_cast (0.5)); + + compile ({graph::variable_cast(variable), + graph::variable_cast(x), + graph::variable_cast(y)}, + {index}, {}, + static_cast (1.0), 0.0); + + x->set(static_cast (2.5)); + y->set(static_cast (-0.5)); + compile ({graph::variable_cast(variable), graph::variable_cast(x), graph::variable_cast(y)},