AI/LLM Paper review

2022_Interleaving Retrieval with Chain-of-Thought Reasoning for Knowledge-Intensive Multi-Step Questions (IR-RAG 실습)

jhworld 2024. 2. 18. 15:41
다 다시 쓰일 것입니다.

What is IR-RAG

QA 성능 향상을 위해서 CoT를 활용하는데, 매 step 마다 RAG를 수행하여 CoT의 성능을 향상시킨다.

 


개념은 매우 간단하니 이제 실험의 구성 요소를 대충 보자.

- Data

: We evaluate our method on 4 multi-step QA datasets in the open-domain setting: HotpotQA (Yang et al., 2018), 2WikiMul-ihopQA (Ho et al., 2020), answerable subset of MuSiQue (Trivedi et al., 2022), and answerable subset of IIRC (Ferguson et al., 2020).

 

- Retriever

: We use BM25 

 

- Metric

: F1, EM score.


코드를 살펴 본 다음 저 자세하게 보자.

Installation

: 필요한 모듈을 설치한다.

conda create -n ircot python=3.8.0 -y && conda activate ircot
pip install -r requirements.txt

# Python에서 NLP(자연어 처리) 작업을 위해 널리 사용되는 spaCy 라이브러리
# 를 이용하여 en_core_web_sm이라는 영어 모델을 다운로드하고 설치하는 동작을 수행합니다.
python -m spacy download en_core_web_sm

 

Prepare Data

원래 data set (=raw_data)라고 부르고, 이 raw_data를 가지고 와서 이 논문에서 쓰이는 자료 형태로 한번 정재한게 processed_data 라고 부른다. 이 부분은 나중에 더 자세하게 보면 좋을 거 같고 지금 당장은 그냥 processed_data를 바로 받아놓자.

./download/processed_data.sh

 

실행이 끝나면 아래와 같이 데이터 들이 다운 받아져 있는데 어떤 형태인진 아직 잘 모르겠네.

Prepare Prompts

실험에 사용된 모든 prompts가 저장이 되어있다. 그리고 당연히 이런 prompts들은 사람이 쓴게 아니라 코드로 생성 된 것이다. 그리고 여기서는 그 prompt를 생성하는 코드까지 모두 줬다. 

- prompt_generator: prompts를 생성하는 파이썬 코드

- prompts: 위에를 통해 생성된 프롬프트를 저장해 둔 곳.

 

이 부분도 나중에 잘 보면 좋겠지만 지금 당장은 그냥 prompts에 생성 되어있는거 먼저 보자.


 

프로그램 작동을 살펴보면 가장 기본 흐름은 아래와 같다.

reproduce.sh -> runner.py -> run.py.

 

- 대충 살펴보니 reproduce.sh 는 크게 4단계로 이루어져있다.

echo ">>>> Instantiate experiment configs with different HPs and write them in files. <<<<"
python runner.py $1 $2 $3 write --prompt_set 1

echo ">>>> Run experiments for different HPs on the dev set. <<<<"
python runner.py $1 $2 $3 predict --prompt_set 1

echo ">>>> Run evaluation for different HPs on the dev set. <<<<"
python runner.py $1 $2 $3 evaluate --prompt_set 1

echo ">>>> Show results for experiments with different HPs <<<<"
python runner.py $1 $2 $3 summarize --prompt_set 1

 

- $1,$2,$3에 넣을 수 있는 값으로는 아래와 같다.

$1=("ircot" "ircot_qa" "oner" "oner_qa" "nor_qa")
$2=("codex" "flan-t5-xxl" "flan-t5-xl" "flan-t5-large" "flan-t5-base", "none")
$3=("hotpotqa" "2wikimultihopqa" "musique" "iirc")

 

- 나는 차근차근 코드를 따라 들어갈 예정이니 ($1, $2, $3) = (ircot, codex, hotpotqa) 이 1 set에 대해서만 먼저 돌려보겠다. 이러면 runner.py 에 들어갈 argu는 아래와 같다.

echo ">>>> Instantiate experiment configs with different HPs and write them in files. <<<<"
python runner.py ircot codex hotpotqa write --prompt_set 1

echo ">>>> Run experiments for different HPs on the dev set. <<<<"
python runner.py ircot codex hotpotqa predict --prompt_set 1

echo ">>>> Run evaluation for different HPs on the dev set. <<<<"
python runner.py ircot codex hotpotqa evaluate --prompt_set 1

echo ">>>> Show results for experiments with different HPs <<<<"
python runner.py ircot codex hotpotqa summarize --prompt_set 1

 

 

- 이 argu가 들어갈 때, 최종 적으로 run.py에는 어떤 식으로 argu가 들어가는지 확인 하기 위해서 아래와 같이 runner.py에 debug 걸고 run.py가 돌기 직전에 break point 잡은 다음 argu를 넣어준다. 

 

그 결과 위에 argu 들은 최종적으로 아래와 같이 run.py에 들어가게 된다.

# python runner.py ircot codex hotpotqa write --prompt_set 1
python run.py write ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --no_diff

# python runner.py ircot codex hotpotqa predict --prompt_set 1
python run.py predict ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl --skip_if_exists --silent

# python runner.py ircot codex hotpotqa evaluate --prompt_set 1
python run.py evaluate ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl

# python runner.py ircot codex hotpotqa summarize --prompt_set 1
python run.py summarize ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl

 

 

- 이제 run.py의 main 함수 가장 첫 줄에 break point 걸고 debug mode로 argu 넣고 한줄씩 따라가면 끝.

 


1. python run.py write ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --no_diff

위의 코드를 돌리면, 그냥 실험 할 때 쓰이는 config를 생성한다. 이게 어떻게 쓰이는진 더 뒤에 보자.

 

2. python run.py predict ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl --skip_if_exists --silent

run.py 속에는 총 7개의 subprocess.call이 있는데. 위에 argu로 돌리면 line 727번에 있는 subprocess.call이 불리고 그때, predict.py가 불린다.

python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent

 

 

우리는 이 중 첫번째 argu로 다시 predict.py를 돌려보자.

predict.py 속에는 또 3개의 subprocess.call이 존재하고, 이 3개가 차례대로 모두 불린다. 불린 command 는 아래와 같다.

# Run predict_command: 
RETRIEVER_HOST=http://localhost RETRIEVER_PORT=8000 LLM_SERVER_HOST=http://localhost LLM_SERVER_PORT=8010 python -m commaqa.inference.configurable_inference --config instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet --input processed_data/hotpotqa/dev_subsampled.jsonl --output predictions/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1/prediction__hotpotqa_to_hotpotqa__dev_subsampled.json --silent

# Run evaluate_command: 
python evaluate.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl

# Run evaluate_command: 
python evaluate.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --official

 

다시 또 한줄 한줄 살펴봐야지 뭐... 아오 귀찮아... 첫번째 명령어를 살펴보면 환경변수 4개를 아래와 같이 선언 한 다음.

RETRIEVER_HOST=http://localhost 
RETRIEVER_PORT=8000 
LLM_SERVER_HOST=http://localhost 
LLM_SERVER_PORT=8010

 

commaqa.inference.configurable_inference.py 를 실행시키고, 그때 argu로 아래의 4개를 넘기네.

 --config instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet
 --input processed_data/hotpotqa/dev_subsampled.jsonl
 --output predictions/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1/prediction__hotpotqa_to_hotpotqa__dev_subsampled.json
 --silent

 

Debugger로 따라 가려고 아래와 같이 환경변수는 임으로 선언해주고, debugger를 run하니 error가 발생하네. 그 이유는  debugger를 "Python Debugger: Current File with Arguments" 로 실행 시키면, dir을 타고 들어가서  configurable_inference.py 를 실행 시키기 때문이다.

 

이걸 해결 하기 위해서는 launch.json에 들어가서 아래와 같이 추가하여 debugger를  "Python Debugger: Module"로 실행 시키고 module의 이름을 "commaqa.inference.configurable_inference" 이렇게 설정해서 debugger가 안타고 들어가게 만들면 된다.

 # launch.json
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Module",
            "type": "debugpy",
            "request": "launch",
            "module": "commaqa.inference.configurable_inference",
            "args": "${command:pickArgs}",
        },
        {
            "name": "Python Debugger: Current File with Arguments",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": "${command:pickArgs}",
            "subProcess": true,
        },
    ]
}

 

구성을 살펴보면 매우 간단하다.

 

 

build_decompser_and_models 에 들어가면

retriever : "retrieve_and_reset_paragraphs" -> class RetrieveAndResetParagraphsParticipant

cot_reasoning_gen: "step_by_step_cot_gen" -> class StepByStepCOTGenParticipant

exit_controller: "step_by_step_exit_controller" -> class StepByStepExitControllerParticipant

그리고 위의 3 class는 모두 ParticipantModel를 mom class로 가지고 있네.

 

model_map은 그냥 jsonnet으로 부터 parsing한 model의 정보들을 map으로 가지고 있는 것.

 

ModelController는 그냥 model_map을 가지고 있는 녀석이라 생각하면 될듯.

 

근데 결국 decompser도 model_map을 가지고 있는데... 쩝 코드 대충 만들었나보지.

 

이제 load_reader를 보자.

별거 없고 그냥 구현되어 있는게 class MultiParaRCReader 이거 뿐이네..

 

다음에 오면 이제 MultiParaRCReader.read_examples( ) 함수를 살펴보면 될듯. 여기에 input으로 'processed_data/hotpotqa/dev_subsampled.jsonl' 가 들어가거든.