This is the original implementation of "CorrectionLM: Self-Corrections with SLM for Dialogue State Tracking" by Chia-Hsuan Lee, Hao Cheng and Mari Ostendorf.
The task is to track user intents predefined by a schema (ontology) in a multi-turn conversation with an agent. CorrectionLM is a novel correction framework that enables Small Language Models (e.g. Llama3-8B) to self-correct using in-context exemplars without LLM (e.g. GPT-4o) involvement.
Installation | Preprocess | Training | Inference | Evaluation | Citation
Create a conda environment
conda env create -f env.yml
To download and preprocess the MultiWOZ 2.4 dataset:
cd data
sh preprocess_mwoz.sh
To download and preprocess the Schema-Guided Dialogue (SGD) dataset:
sh preprocess_sgd.sh
The trained retrievers are saved in the retriever/expts folder. Each subfolder is a trained retriever.
To skip retriever fine-tuning, download one of our retrievers fine-tuned on the 5% training set.
Download and unzip mwoz_5p_SBERT.zip (for SGD, sgd_5p_SBERT.zip) and put the folder in retriever/expts.
To fine-tune a retriever yourself, first embed all the utterances with SBERT (all-mpnet-base-v2):
cd retriever/code/
python pretrained_embed_index.py
This will save all the embeddings in retriever/expts/all_mpnet_base_v2.
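To illustrate how a pool of embedded utterances can be used to select in-context exemplars, here is a minimal sketch of nearest-neighbor retrieval by cosine similarity. It is not the repository's code: the function names, the turn identifiers, and the three-dimensional toy vectors are assumptions made for illustration, and real SBERT embeddings would be much larger.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: Sequence[float],
          pool: Dict[str, Sequence[float]],
          k: int) -> List[Tuple[str, float]]:
    """Return the k pool entries most similar to the query, best first."""
    scored = [(turn_id, cosine(query, vec)) for turn_id, vec in pool.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical selection pool: turn id -> toy embedding.
pool = {
    "dlg1-turn0": [1.0, 0.0, 0.0],
    "dlg1-turn1": [0.9, 0.1, 0.0],
    "dlg2-turn0": [0.0, 1.0, 0.0],
}

nearest = top_k([1.0, 0.05, 0.0], pool, k=2)
print([turn_id for turn_id, _ in nearest])
```

In this toy pool the two turns from the first dialogue are returned, since their vectors point in nearly the same direction as the query.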
To finetune SBERT with data in ../../data/mw24_5p_train.json, run
python retriever_finetuning.py \
--train_fn ../../data/mw24_5p_train.json \
--save_name mw24_5p_SBERT \
--epoch 15 \
--topk 10 \
--toprange 200
This will save the embedding model and the pre-embedded selection pool to retriever/expts/mw24_5p_SBERT.
The first step of training is to obtain SLM predictions with in-context learning (ICL); these predictions provide the supervision signal for correction training.
For MultiWOZ,
cd runs
python runs/run_mwoz_ICL_5shot.py \
--output_dir expts/mwoz/llama3_on_train5p_zeroshot/ \
--lm meta-llama/Meta-Llama-3-8B-Instruct \
--test_fn data/mw24_5p_train.json \
--mwz_ver 2.4
You can also use GPT-4o for comparison:
python runs/run_mwoz_ICL_5shot.py \
--output_dir expts/mwoz/gpt4o_on_train5p_zeroshot/ \
--lm gpt4 \
--test_fn data/mw24_5p_train.json \
--mwz_ver 2.4
For SGD,
python runs/run_sgd_ICL_5shot.py \
--lm meta-llama/Meta-Llama-3-8B-Instruct \
--retriever_dir retriever/expts/sgd_5p_SBERT/ \
--output_dir expts/sgd/llama3_on_train5p_zeroshot/ \
--test_fn data/sgd/sgd_train_5p.json
Then we create the in-context exemplars used to finetune the SLM. Unlike traditional ICL methods that only consider the input and gold output, we also incorporate the model's (erroneous) self-predictions.
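As a rough illustration of the idea, the sketch below builds a correction-style prompt in which each exemplar carries the dialogue context, the model's initial prediction, and the gold state. The bracketed tags, the field names, and both function names are hypothetical; the actual template is produced by the create_*_llama_SFT_prompt.py scripts and may look quite different.

```python
import json
from typing import Dict, List


def format_exemplar(context: str,
                    predicted: Dict[str, str],
                    gold: Dict[str, str]) -> str:
    """Render one exemplar: input, the model's own prediction, gold output."""
    return (
        f"[context] {context}\n"
        f"[initial prediction] {json.dumps(predicted, sort_keys=True)}\n"
        f"[corrected state] {json.dumps(gold, sort_keys=True)}"
    )


def build_correction_prompt(exemplars: List[dict],
                            test_context: str,
                            test_prediction: Dict[str, str]) -> str:
    """Concatenate exemplars, then leave the final correction for the model."""
    blocks = [
        format_exemplar(ex["context"], ex["predicted"], ex["gold"])
        for ex in exemplars
    ]
    blocks.append(
        f"[context] {test_context}\n"
        f"[initial prediction] {json.dumps(test_prediction, sort_keys=True)}\n"
        "[corrected state]"
    )
    return "\n\n".join(blocks)


# Toy data (made up for illustration, not taken from MultiWOZ or SGD).
exemplars = [{
    "context": "User: I need a hotel in the north.",
    "predicted": {"hotel-area": "south"},
    "gold": {"hotel-area": "north"},
}]
prompt = build_correction_prompt(
    exemplars, "User: Book a 4 star hotel.", {"hotel-stars": "3"}
)
print(prompt)
```

The prompt ends with an open "[corrected state]" slot, so the model is asked to produce the refined state given its own first attempt.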
For MultiWOZ,
python data/create_mwoz_llama_SFT_prompt.py \
--train_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
--retriever_dir retriever/expts/mwoz_5p_SBERT/ \
--output_fn data/mwoz/llama3_on_train5p_zeroshot_ICL_prompt.json \
--test_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
--mwz_ver 2.4
For SGD,
python data/create_sgd_llama_SFT_prompt.py \
--train_fn expts/sgd/llama3_on_train5p_zeroshot/running_log.json \
--retriever_dir retriever/expts/sgd_5p_SBERT/ \
--output_fn data/sgd/llama3_on_train5p_zeroshot_ICL_prompt.json \
--test_fn expts/sgd/llama3_on_train5p_zeroshot/running_log.json
The second step is to train the SLM. To keep computation efficient, we adopt QLoRA for training, i.e. we quantize the SLM and then insert LoRA adapters.
For MultiWOZ,
cd runs
sh train_mwoz.sh
For SGD,
sh train_sgd.sh
The first step of inference is to get initial predictions from a non-finetuned SLM.
For MultiWOZ,
python runs/run_mwoz_ICL_vanilla.py \
--train_fn data/mw24_5p_train.json \
--retriever_dir retriever/expts/mwoz_5p_SBERT/ \
--lm meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir expts/mwoz/llama3_train_5p_on_test100p/ \
--test_fn data/mw24_100p_test.json \
--mwz_ver 2.4
For SGD,
python runs/run_sgd_ICL_vanilla.py \
--train_fn data/sgd/sgd_train_5p.json \
--retriever_dir retriever/expts/sgd_5p_SBERT/ \
--lm meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir expts/sgd/llama3_train_5p_on_test100p/ \
--test_fn data/sgd/sgd_test_100p.json
We then prompt the correction-tuned SLM (correction SLM) to refine the initial predictions made in the first step.
For MultiWOZ,
python runs/run_mwoz_correctionlm.py \
--train_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
--retriever_dir retriever/expts/mwoz_5p_SBERT/ \
--output_dir expts/mwoz/correction_outputs/llama_example_llama_inference_train5p_test100p/ \
--test_fn expts/mwoz/llama3_train_5p_on_test100p/running_log.json \
--mwz_ver 2.4 \
--model models/mwoz/sft_llama3_on_train5p_zeroshot/
For SGD,
python runs/run_sgd_correctionlm.py \
--train_fn expts/sgd/llama3_on_train5p_zeroshot/running_log.json \
--retriever_dir retriever/expts/sgd_5p_SBERT/ \
--output_dir expts/sgd/correction_outputs/llama_example_llama_inference_train5p_test100p/ \
--test_fn expts/sgd/llama3_train_5p_on_test100p/running_log.json \
--model models/sgd/sft_llama3_on_train5p_zeroshot/
Compute the joint goal accuracy (JGA) and F1 at both the dialogue level (DST) and the turn level (TLB).
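For readers unfamiliar with these metrics, the following sketch shows one common way to compute them. It rests on several assumptions not confirmed by this README: that TLB refers to the per-turn state change while DST refers to the accumulated dialogue state, that JGA counts a turn as correct only on an exact match, and that F1 is micro-averaged over (slot, value) pairs. All function names and the toy data are hypothetical, and eval/eval_result.py may differ in its details.

```python
from typing import Dict, List

State = Dict[str, str]  # slot name -> value, e.g. {"hotel-area": "north"}


def joint_goal_accuracy(preds: List[State], golds: List[State]) -> float:
    """Fraction of turns whose predicted state equals the gold state exactly."""
    if not golds:
        return 0.0
    return sum(1 for p, g in zip(preds, golds) if p == g) / len(golds)


def slot_f1(preds: List[State], golds: List[State]) -> float:
    """Micro-averaged F1 over (slot, value) pairs (an assumed averaging)."""
    tp = fp = fn = 0
    for p, g in zip(preds, golds):
        p_items, g_items = set(p.items()), set(g.items())
        tp += len(p_items & g_items)
        fp += len(p_items - g_items)
        fn += len(g_items - p_items)
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def accumulate(turn_updates: List[State]) -> List[State]:
    """Merge per-turn updates into the cumulative state after each turn."""
    state: State = {}
    states = []
    for update in turn_updates:
        state = {**state, **update}
        states.append(state)
    return states


# Toy dialogue: the first turn is wrong, the next two updates are right.
gold_updates = [{"hotel-area": "north"}, {"hotel-stars": "4"}, {"hotel-parking": "yes"}]
pred_updates = [{"hotel-area": "south"}, {"hotel-stars": "4"}, {"hotel-parking": "yes"}]

tlb_jga = joint_goal_accuracy(pred_updates, gold_updates)
dst_jga = joint_goal_accuracy(accumulate(pred_updates), accumulate(gold_updates))
tlb_f1 = slot_f1(pred_updates, gold_updates)
dst_f1 = slot_f1(accumulate(pred_updates), accumulate(gold_updates))
print(tlb_jga, dst_jga, tlb_f1, dst_f1)
```

Under these assumptions a single early mistake is carried forward into every later cumulative state, so the dialogue-level scores are lower than the turn-level ones. The repository's own evaluation script is run as follows: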
python eval/eval_result.py \
--eval_fn expts/mwoz/correction_outputs/llama_example_llama_inference_train5p_test100p/running_log.json \
--eval_mode second_pass # use first_pass to score the results produced by the non-finetuned LM
If you find our code or paper useful, please cite the paper:
@article{lee2024correctionlm,
title={CorrectionLM: Self-Corrections with SLM for Dialogue State Tracking},
author={Lee, Chia-Hsuan and Cheng, Hao and Ostendorf, Mari},
journal={arXiv preprint arXiv:2410.18209},
year={2024}
}
Please contact Chia-Hsuan Lee (chiahsuan.li[at]gmail.com) with questions and suggestions.
