diff --git a/retriv/base_retriever.py b/retriv/base_retriever.py index 8e7bec6..ca0cc5e 100644 --- a/retriv/base_retriever.py +++ b/retriv/base_retriever.py @@ -77,7 +77,7 @@ def prepare_results(self, doc_ids: List[str], scores: np.ndarray) -> List[dict]: return results def map_internal_ids_to_original_ids(self, doc_ids: Iterable) -> List[str]: - return [self.id_mapping[doc_id] for doc_id in doc_ids] + return [self.id_mapping[doc_id] for doc_id in doc_ids if doc_id != -1] def save(self): raise NotImplementedError() diff --git a/retriv/hybrid_retriever.py b/retriv/hybrid_retriever.py index 431e9ad..e199450 100644 --- a/retriv/hybrid_retriever.py +++ b/retriv/hybrid_retriever.py @@ -251,8 +251,9 @@ def search( List: results. """ - sparse_results = self.sparse_retriever.search(query, False, 1_000) - dense_results = self.dense_retriever.search(query, False, 1_000) + sub_cutoff = max(cutoff, 1_000) + sparse_results = self.sparse_retriever.search(query, False, sub_cutoff) + dense_results = self.dense_retriever.search(query, False, sub_cutoff) hybrid_results = self.merger.fuse([sparse_results, dense_results]) return ( self.prepare_results( @@ -282,8 +283,9 @@ def msearch( Dict: results. """ - sparse_results = self.sparse_retriever.msearch(queries, 1_000) - dense_results = self.dense_retriever.msearch(queries, 1_000, batch_size) + sub_cutoff = max(cutoff, 1_000) + sparse_results = self.sparse_retriever.msearch(queries, sub_cutoff) + dense_results = self.dense_retriever.msearch(queries, sub_cutoff, batch_size) return self.merger.mfuse([sparse_results, dense_results], cutoff) def bsearch(