diff --git a/library/src/amd_detail/rocblaslt/src/UserDrivenTuningParser.cpp b/library/src/amd_detail/rocblaslt/src/UserDrivenTuningParser.cpp index 93a5c3f721..e0c5f960d9 100644 --- a/library/src/amd_detail/rocblaslt/src/UserDrivenTuningParser.cpp +++ b/library/src/amd_detail/rocblaslt/src/UserDrivenTuningParser.cpp @@ -43,7 +43,7 @@ namespace TensileLite if(problemSolution.second > 0) { - auto sol_iter = m_override.find(problemSolution.first); + auto sol_iter = m_override.find_range(problemSolution.first); for(auto sol_idx = sol_iter.first; sol_idx != sol_iter.second; sol_idx++) { if(sol_idx->second == problemSolution.second) diff --git a/library/src/amd_detail/rocblaslt/src/include/UserDrivenTuningParser.hpp b/library/src/amd_detail/rocblaslt/src/include/UserDrivenTuningParser.hpp index 5d646d822c..af79e59cdb 100644 --- a/library/src/amd_detail/rocblaslt/src/include/UserDrivenTuningParser.hpp +++ b/library/src/amd_detail/rocblaslt/src/include/UserDrivenTuningParser.hpp @@ -160,27 +160,23 @@ namespace TensileLite int size() { - std::shared_lock lock(m_mutex); auto size = m_override.size(); return size; } - auto find(const ProblemOverride& prob_key) + auto find_range(const ProblemOverride& prob_key) { - std::shared_lock lock(m_mutex); auto iter = m_override.equal_range(prob_key); return iter; } void add(const std::pair& problemSolution) { - std::lock_guard lock(m_mutex); m_override.insert(problemSolution); } void erase(std::multimap::iterator& sol_idx) { - std::lock_guard lock(m_mutex); m_override.erase(sol_idx); } @@ -192,7 +188,6 @@ namespace TensileLite private: std::multimap m_override; std::mutex m_guard; - std::shared_timed_mutex m_mutex; }; } // namespace Tensile diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp index 0693a51f94..3f9f852a05 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp @@ -126,7 +126,7 @@ bool problem_override_from_file(rocblaslt_handle& handle, std::vector overrideResults; std::vector solutionIndex(1); TensileLite::ProblemOverride prob_key(RocblasltContractionProblem2ProblemOverride(problem)); - auto sol_iter = m_override.find(prob_key); + auto sol_iter = m_override.find_range(prob_key); for(auto sol_idx = std::make_reverse_iterator(sol_iter.second); !success && sol_idx != std::make_reverse_iterator(sol_iter.first); @@ -196,7 +196,7 @@ bool problem_override_from_file_cpp( std::vector overrideResults; std::vector solutionIndex(1); TensileLite::ProblemOverride prob_key(TensileDataGemm2ProblemOverride(gemmData)); - auto sol_iter = m_override.find(prob_key); + auto sol_iter = m_override.find_range(prob_key); for(auto sol_idx = std::make_reverse_iterator(sol_iter.second); !success && sol_idx != std::make_reverse_iterator(sol_iter.first);