From b72de2d8dae2441623210a333177aaa34570272f Mon Sep 17 00:00:00 2001 From: ai-leehm <97leehm@gmail.com> Date: Sat, 25 May 2024 20:03:53 +0900 Subject: [PATCH] Update run_short_form.py --- retrieval_lm/run_short_form.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/retrieval_lm/run_short_form.py b/retrieval_lm/run_short_form.py index 034e298..a7f9476 100644 --- a/retrieval_lm/run_short_form.py +++ b/retrieval_lm/run_short_form.py @@ -76,13 +76,13 @@ def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15 if id not in pred_log_probs[0]: score_dict[tok] = -100 prob = pred_log_probs[0][id] - score_dict[tok] = float(prob) + score_dict[tok] = np.exp(float(prob)) do_retrieve = score_dict["[Retrieval]"] / ( score_dict["[Retrieval]"] + score_dict["[No Retrieval]"]) > threshold else: do_retrieve = "[Retrieval]" in pred - if do_retrieve is True: + if do_retrieve: evidence_augmented_inputs = [prompt + "[Retrieval]{0}\n{1}".format( para["title"], para["text"]) for para in evidences] sampling_params = SamplingParams(