
CorrectionLM: Self-Corrections with SLM for Dialogue State Tracking

This is the original implementation of "CorrectionLM: Self-Corrections with SLM for Dialogue State Tracking" by Chia-Hsuan Lee, Hao Cheng and Mari Ostendorf.

The task is to track user intents predefined by a schema (ontology) in a multi-turn conversation with an agent. CorrectionLM is a novel correction framework that enables Small Language Models (e.g. Llama3-8B) to self-correct using in-context exemplars without LLM (e.g. GPT-4o) involvement.
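To make the task concrete, here is a minimal sketch of a dialogue state as slot-value pairs accumulated over turns. The slot names imitate MultiWOZ's domain-slot style, but the example dialogue, the `update_state` helper, and the `[DELETE]` marker are illustrative assumptions, not code or conventions from this repository.

```python
from typing import Dict, List

State = Dict[str, str]

def update_state(previous: State, turn_update: State) -> State:
    """Return the accumulated state after applying one turn's slot changes."""
    new_state = dict(previous)  # copy so earlier states stay untouched
    for slot, value in turn_update.items():
        if value == "[DELETE]":  # hypothetical marker for a retracted slot
            new_state.pop(slot, None)
        else:
            new_state[slot] = value
    return new_state

# Hypothetical per-turn changes expressed by the user in a three-turn dialogue.
turn_updates: List[State] = [
    {"hotel-area": "centre"},
    {"hotel-stars": "4", "hotel-area": "north"},  # the user changes the area
    {"restaurant-food": "italian"},
]

state: State = {}
history: List[State] = []
for update in turn_updates:
    state = update_state(state, update)
    history.append(state)

print(history[-1])
```

The tracker's job is to predict each entry of `history` from the dialogue text alone; the per-turn updates shown here are what a turn-level evaluation would compare.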

Installation | Preprocess | Training | Inference | Evaluation | Citation

Installation

Create a conda environment

conda env create -f env.yml 

Download and Preprocess Data

To download and preprocess the MultiWOZ 2.4 dataset:

cd data
sh preprocess_mwoz.sh

To download and preprocess the Schema-Guided Dialogue (SGD) dataset:

sh preprocess_sgd.sh

Retriever

The trained retrievers are saved in the retriever/expts folder. Each subfolder is a trained retriever.

Retriever quickstart

If you want to skip retriever finetuning, download one of our retrievers finetuned on the 5% training set. Download and unzip mwoz_5p_SBERT.zip (for SGD, sgd_5p_SBERT.zip) and put the folder in retriever/expts.

Retriever details

First, embed all the utterances with SBERT (all-mpnet-base-v2):

cd retriever/code/
python pretrained_embed_index.py

This will save all the embeddings in retriever/expts/all_mpnet_base_v2. To finetune SBERT with data in ../../data/mw24_5p_train.json, run

python retriever_finetuning.py \
--train_fn ../../data/mw24_5p_train.json \
--save_name mw24_5p_SBERT \
--epoch 15 \
--topk 10 \
--toprange 200

This will save the embedding model and the pre-embedded selection pool to retriever/expts/mw24_5p_SBERT.
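As a rough illustration of how a pre-embedded selection pool can be used, the sketch below ranks pool entries by cosine similarity to a query embedding and keeps the top k. This assumes nearest-neighbor retrieval over fixed vectors; the toy 3-dimensional embeddings, the turn IDs, and the `retrieve_topk` helper are hypothetical, and the repository's retriever may work differently.

```python
import math
from typing import Dict, List, Sequence, Tuple

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def retrieve_topk(query: Sequence[float],
                  pool: Dict[str, Sequence[float]],
                  k: int) -> List[Tuple[str, float]]:
    """Return the k pool entries most similar to the query, best first."""
    scored = [(turn_id, cosine(query, emb)) for turn_id, emb in pool.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]

# Toy stand-ins for SBERT embeddings of training turns (real ones are much larger).
pool = {
    "train_turn_0": [1.0, 0.0, 0.0],
    "train_turn_1": [0.9, 0.1, 0.0],
    "train_turn_2": [0.0, 1.0, 0.0],
    "train_turn_3": [0.0, 0.0, 1.0],
}
query = [1.0, 0.2, 0.0]  # toy embedding of the test turn

top = retrieve_topk(query, pool, k=2)
for turn_id, score in top:
    print(f"{turn_id}\t{score:.3f}")
```

The retrieved turns would then serve as in-context exemplars for the prompt.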

Training

The first step of the training is to obtain SLM predictions using ICL in order to provide supervision signals for the correction training.

For MultiWOZ,

# Run from the repository root; the paths below (runs/, expts/, data/) are relative to it
python runs/run_mwoz_ICL_5shot.py \
      --output_dir expts/mwoz/llama3_on_train5p_zeroshot/  \
      --lm meta-llama/Meta-Llama-3-8B-Instruct \
      --test_fn data/mw24_5p_train.json \
      --mwz_ver 2.4

You can also use GPT-4o for comparison:

python runs/run_mwoz_ICL_5shot.py \
      --output_dir expts/mwoz/gpt4o_on_train5p_zeroshot/  \
      --lm gpt4 \
      --test_fn data/mw24_5p_train.json \
      --mwz_ver 2.4

For SGD,

python runs/run_sgd_ICL_5shot.py \
      --lm meta-llama/Meta-Llama-3-8B-Instruct \
      --retriever_dir retriever/expts/sgd_5p_SBERT/ \
      --output_dir expts/sgd/llama3_on_train5p_zeroshot/  \
      --test_fn data/sgd/sgd_train_5p.json

Then we create the in-context exemplars used to finetune the SLM. Unlike traditional ICL methods that only consider the input and the gold output, we also incorporate the model's own (erroneous) predictions.
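The idea can be sketched as follows: each exemplar shows the dialogue, the model's initial prediction, and the gold state, and the test turn ends with an empty slot for the correction. The field labels and the record keys (`context`, `pred`, `gold`) are assumptions made for illustration; the actual prompt template is defined by the `create_*_llama_SFT_prompt.py` scripts.

```python
import json
from typing import Dict, List, Optional

State = Dict[str, str]

def format_example(context: str, initial_pred: State,
                   gold: Optional[State] = None) -> str:
    """Render one exemplar; leave the correction empty when gold is unknown."""
    corrected = "" if gold is None else json.dumps(gold, sort_keys=True)
    lines = [
        f"Dialogue: {context}",
        f"Initial prediction: {json.dumps(initial_pred, sort_keys=True)}",
        f"Corrected state: {corrected}".rstrip(),
    ]
    return "\n".join(lines)

def build_correction_prompt(exemplars: List[Dict], test_context: str,
                            test_initial_pred: State) -> str:
    """Concatenate solved exemplars, then the unsolved test turn."""
    blocks = [format_example(e["context"], e["pred"], e["gold"]) for e in exemplars]
    blocks.append(format_example(test_context, test_initial_pred))
    return "\n\n".join(blocks)

# Hypothetical exemplar: the first-pass model misspelled the slot value.
exemplars = [{
    "context": "User: I need a hotel in the centre.",
    "pred": {"hotel-area": "center"},
    "gold": {"hotel-area": "centre"},
}]
prompt = build_correction_prompt(
    exemplars,
    test_context="User: Find me a 4 star hotel.",
    test_initial_pred={"hotel-stars": "four"},
)
print(prompt)
```

During finetuning, the text after the final `Corrected state:` would be the gold state the SLM learns to generate.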

For MultiWOZ,

python data/create_mwoz_llama_SFT_prompt.py \
      --train_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
      --retriever_dir retriever/expts/mwoz_5p_SBERT/ \
      --output_fn data/mwoz/llama3_on_train5p_zeroshot_ICL_prompt.json  \
      --test_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
      --mwz_ver 2.4

For SGD,

python data/create_sgd_llama_SFT_prompt.py \
      --train_fn expts/sgd/llama3_on_train5p_zeroshot/running_log.json \
      --retriever_dir retriever/expts/sgd_5p_SBERT/ \
      --output_fn data/sgd/llama3_on_train5p_zeroshot_ICL_prompt.json  \
      --test_fn expts/sgd/llama3_on_train5p_zeroshot/running_log.json

The second step is to train the SLM. To keep training compute-efficient, we adopt QLoRA, i.e. we quantize the SLM and then insert LoRA adapters. For MultiWOZ,

cd runs
sh train_mwoz.sh

For SGD,

sh train_sgd.sh
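To see why quantizing the base model and training only LoRA adapters saves compute and memory, here is a back-of-the-envelope calculation for a single weight matrix. The 4096 x 4096 shape and rank 16 are illustrative assumptions, not values read from `train_mwoz.sh` or `train_sgd.sh`, and the byte counts ignore the small overhead of quantization constants.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the low-rank pair A (rank x d_in) and B (d_out x rank)."""
    return rank * d_in + d_out * rank

def weight_bytes(num_params: int, bits_per_param: int) -> int:
    """Storage for the weights alone at the given precision."""
    return num_params * bits_per_param // 8

d_in = d_out = 4096  # illustrative projection size (assumption)
rank = 16            # illustrative LoRA rank (assumption)

full = d_in * d_out
adapter = lora_trainable_params(d_in, d_out, rank)

print(f"frozen matrix: {full:,} params, "
      f"{weight_bytes(full, 16):,} bytes at 16-bit, "
      f"{weight_bytes(full, 4):,} bytes at 4-bit")
print(f"trainable adapter: {adapter:,} params "
      f"({adapter / full:.2%} of the frozen matrix)")
```

Under these assumptions the frozen matrix shrinks fourfold when stored in 4-bit, and gradients are only needed for well under 1% as many parameters.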

Inference

The first step of inference is to get initial predictions from a non-finetuned SLM.

For MultiWOZ,

python runs/run_mwoz_ICL_vanilla.py \
      --train_fn data/mw24_5p_train.json \
      --retriever_dir retriever/expts/mwoz_5p_SBERT/ \
      --lm meta-llama/Meta-Llama-3-8B-Instruct \
      --output_dir expts/mwoz/llama3_train_5p_on_test100p/  \
      --test_fn data/mw24_100p_test.json \
      --mwz_ver 2.4

For SGD,

python runs/run_sgd_ICL_vanilla.py \
      --train_fn data/sgd/sgd_train_5p.json \
      --retriever_dir retriever/expts/sgd_5p_SBERT/ \
      --lm meta-llama/Meta-Llama-3-8B-Instruct \
      --output_dir expts/sgd/llama3_train_5p_on_test100p/  \
      --test_fn data/sgd/sgd_test_100p.json

We then prompt the correction-tuned SLM (correction SLM) to refine the initial predictions made in the first step. For MultiWOZ,

python runs/run_mwoz_correctionlm.py \
    --train_fn expts/mwoz/llama3_on_train5p_zeroshot/running_log.json \
    --retriever_dir retriever/expts/mwoz_5p_SBERT/ \
    --output_dir expts/mwoz/correction_outputs/llama_example_llama_inference_train5p_test100p/  \
    --test_fn expts/mwoz/llama3_train_5p_on_test100p/running_log.json \
    --mwz_ver 2.4 \
    --model models/mwoz/sft_llama3_on_train5p_zeroshot/

For SGD,

python runs/run_sgd_correctionlm.py \
    --train_fn  expts/sgd/llama3_on_train5p_zeroshot/running_log.json \
    --retriever_dir retriever/expts/sgd_5p_SBERT/ \
    --output_dir expts/sgd/correction_outputs/llama_example_llama_inference_train5p_test100p/  \
    --test_fn expts/sgd/llama3_train_5p_on_test100p/running_log.json \
    --model models/sgd/sft_llama3_on_train5p_zeroshot/
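The two inference passes can be summarized with the sketch below, where plain functions stand in for the non-finetuned SLM and the correction SLM. The `two_pass_dst` helper, the toy models, and the record layout are hypothetical; only the pass names mirror the `first_pass` / `second_pass` evaluation modes mentioned in this README.

```python
from typing import Callable, Dict, List

State = Dict[str, str]

def two_pass_dst(dialogues: List[str],
                 first_pass: Callable[[str], State],
                 second_pass: Callable[[str, State], State]) -> List[Dict]:
    """Run the initial prediction, then let the correction model refine it."""
    log = []
    for context in dialogues:
        initial = first_pass(context)               # non-finetuned SLM
        corrected = second_pass(context, initial)   # correction-tuned SLM
        log.append({"context": context,
                    "first_pass": initial,
                    "second_pass": corrected})
    return log

# Toy stand-ins for the two models (assumptions for illustration only).
def toy_first_pass(context: str) -> State:
    return {"hotel-area": "center"} if "centre" in context else {}

def toy_second_pass(context: str, initial: State) -> State:
    fixed = dict(initial)
    if fixed.get("hotel-area") == "center":
        fixed["hotel-area"] = "centre"  # normalize to the value used in the dialogue
    return fixed

log = two_pass_dst(["I need a hotel in the centre."],
                   toy_first_pass, toy_second_pass)
print(log[0])
```

Keeping both predictions in each record makes it possible to score the first and second pass from the same log.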

Evaluation

Compute joint goal accuracy (JGA) and F1 at both the dialogue level (DST) and the turn level (TLB).

python eval/eval_result.py \
      --eval_fn expts/mwoz/correction_outputs/llama_example_llama_inference_train5p_test100p/running_log.json \
      --eval_mode second_pass # first_pass to score on the results produced by non-finetuned LM
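For reference, the two metrics can be sketched as below: JGA counts a turn as correct only if the whole predicted state matches the gold state, while F1 gives partial credit per slot-value pair. This is a minimal micro-averaged version written for illustration; `eval/eval_result.py` may normalize values or aggregate differently. The same functions apply to accumulated states (DST) or to per-turn changes (TLB), depending on what is passed in.

```python
from typing import Dict, List

State = Dict[str, str]

def joint_goal_accuracy(preds: List[State], golds: List[State]) -> float:
    """Fraction of turns whose predicted state matches the gold state exactly."""
    if not golds:
        return 0.0
    correct = sum(1 for p, g in zip(preds, golds) if p == g)
    return correct / len(golds)

def slot_f1(preds: List[State], golds: List[State]) -> float:
    """Micro-averaged F1 over (slot, value) pairs across all turns."""
    tp = fp = fn = 0
    for p, g in zip(preds, golds):
        p_pairs, g_pairs = set(p.items()), set(g.items())
        tp += len(p_pairs & g_pairs)
        fp += len(p_pairs - g_pairs)
        fn += len(g_pairs - p_pairs)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Hypothetical two-turn dialogue: the second turn has one wrong slot value.
golds = [{"hotel-area": "centre"},
         {"hotel-area": "centre", "hotel-stars": "4"}]
preds = [{"hotel-area": "centre"},
         {"hotel-area": "centre", "hotel-stars": "3"}]

jga = joint_goal_accuracy(preds, golds)
f1 = slot_f1(preds, golds)
print(f"JGA={jga:.3f}  F1={f1:.3f}")
```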

Citation and Contact

If you find our code or paper useful, please cite the paper:

@article{lee2024correctionlm,
  title={CorrectionLM: Self-Corrections with SLM for Dialogue State Tracking},
  author={Lee, Chia-Hsuan and Cheng, Hao and Ostendorf, Mari},
  journal={arXiv preprint arXiv:2410.18209},
  year={2024}
}

Please contact Chia-Hsuan Lee (chiahsuan.li[at]gmail.com) for questions and suggestions.
