diff --git a/cranelift/jit/src/backend.rs b/cranelift/jit/src/backend.rs index cbed2c3217bf..2f04f67db524 100644 --- a/cranelift/jit/src/backend.rs +++ b/cranelift/jit/src/backend.rs @@ -16,6 +16,7 @@ use cranelift_module::{ }; use log::info; use std::cell::RefCell; +use std::collections::BTreeMap; use std::collections::HashMap; use std::ffi::CString; use std::io::Write; @@ -175,7 +176,9 @@ pub struct JITModule { declarations: ModuleDeclarations, compiled_functions: SecondaryMap>, compiled_data_objects: SecondaryMap>, - code_ranges: Vec<(usize, usize, FuncId)>, + /// Map from a function's start address to its (end address, FuncId), used + /// to resolve a PC back to its function for exception unwinding. + code_ranges: BTreeMap, functions_to_finalize: Vec, data_objects_to_finalize: Vec, } @@ -334,9 +337,6 @@ impl JITModule { data.perform_relocations(|name| self.get_address(name)); } - self.code_ranges - .sort_unstable_by_key(|(start, _end, _)| *start); - // Now that we're done patching, prepare the memory for execution! let branch_protection = if cfg!(target_arch = "aarch64") && use_bti(&self.isa.isa_flags()) { BranchProtection::BTI @@ -367,7 +367,7 @@ impl JITModule { declarations: ModuleDeclarations::default(), compiled_functions: SecondaryMap::new(), compiled_data_objects: SecondaryMap::new(), - code_ranges: Vec::new(), + code_ranges: BTreeMap::new(), functions_to_finalize: Vec::new(), data_objects_to_finalize: Vec::new(), } @@ -381,24 +381,10 @@ impl JITModule { &'a self, pc: usize, ) -> Option<(usize, wasmtime_unwinder::ExceptionTable<'a>)> { - // Search the sorted code-ranges for the PC. - let idx = match self - .code_ranges - .binary_search_by_key(&pc, |(start, _end, _func)| *start) - { - Ok(exact_start_match) => Some(exact_start_match), - Err(least_upper_bound) if least_upper_bound > 0 => { - let last_range_before_pc = &self.code_ranges[least_upper_bound - 1]; - if last_range_before_pc.0 <= pc && pc < last_range_before_pc.1 { - Some(least_upper_bound - 1) - } else { - None - } - } - _ => None, - }?; - - let (start, _, func) = self.code_ranges[idx]; + let (&start, &(end, func)) = self.code_ranges.range(..=pc).next_back()?; + if pc >= end { + return None; + } // Get the ExceptionTable. The "parse" here simply reads two // u32s for lengths and constructs borrowed slices, so it's @@ -521,8 +507,7 @@ impl Module for JITModule { let range_start = ptr.addr(); let range_end = range_start + size; - // These will be sorted when we finalize. - self.code_ranges.push((range_start, range_end, id)); + self.code_ranges.insert(range_start, (range_end, id)); self.functions_to_finalize.push(id);