다 다시 쓰일 것입니다.
What is IR-RAG
QA 성능 향상을 위해서 CoT를 활용하는데, 매 step 마다 RAG를 수행하여 CoT의 성능을 향상시킨다.
개념은 매우 간단하니 이제 실험의 구성 요소를 대충 보자.
- Data
: We evaluate our method on 4 multi-step QA datasets in the open-domain setting: HotpotQA (Yang et al., 2018), 2WikiMul-ihopQA (Ho et al., 2020), answerable subset of MuSiQue (Trivedi et al., 2022), and answerable subset of IIRC (Ferguson et al., 2020).
- Retriever
: We use BM25
- Metric
: F1, EM score.
코드를 살펴 본 다음 저 자세하게 보자.
Installation
: 필요한 모듈을 설치한다.
conda create -n ircot python=3.8.0 -y && conda activate ircot
pip install -r requirements.txt
# Python에서 NLP(자연어 처리) 작업을 위해 널리 사용되는 spaCy 라이브러리
# 를 이용하여 en_core_web_sm이라는 영어 모델을 다운로드하고 설치하는 동작을 수행합니다.
python -m spacy download en_core_web_sm
Prepare Data
원래 data set (=raw_data)라고 부르고, 이 raw_data를 가지고 와서 이 논문에서 쓰이는 자료 형태로 한번 정재한게 processed_data 라고 부른다. 이 부분은 나중에 더 자세하게 보면 좋을 거 같고 지금 당장은 그냥 processed_data를 바로 받아놓자.
./download/processed_data.sh
실행이 끝나면 아래와 같이 데이터 들이 다운 받아져 있는데 어떤 형태인진 아직 잘 모르겠네.
Prepare Prompts
실험에 사용된 모든 prompts가 저장이 되어있다. 그리고 당연히 이런 prompts들은 사람이 쓴게 아니라 코드로 생성 된 것이다. 그리고 여기서는 그 prompt를 생성하는 코드까지 모두 줬다.
- prompt_generator: prompts를 생성하는 파이썬 코드
- prompts: 위에를 통해 생성된 프롬프트를 저장해 둔 곳.
이 부분도 나중에 잘 보면 좋겠지만 지금 당장은 그냥 prompts에 생성 되어있는거 먼저 보자.
프로그램 작동을 살펴보면 가장 기본 흐름은 아래와 같다.
reproduce.sh -> runner.py -> run.py.
- 대충 살펴보니 reproduce.sh 는 크게 4단계로 이루어져있다.
echo ">>>> Instantiate experiment configs with different HPs and write them in files. <<<<"
python runner.py $1 $2 $3 write --prompt_set 1
echo ">>>> Run experiments for different HPs on the dev set. <<<<"
python runner.py $1 $2 $3 predict --prompt_set 1
echo ">>>> Run evaluation for different HPs on the dev set. <<<<"
python runner.py $1 $2 $3 evaluate --prompt_set 1
echo ">>>> Show results for experiments with different HPs <<<<"
python runner.py $1 $2 $3 summarize --prompt_set 1
- $1,$2,$3에 넣을 수 있는 값으로는 아래와 같다.
$1=("ircot" "ircot_qa" "oner" "oner_qa" "nor_qa")
$2=("codex" "flan-t5-xxl" "flan-t5-xl" "flan-t5-large" "flan-t5-base", "none")
$3=("hotpotqa" "2wikimultihopqa" "musique" "iirc")
- 나는 차근차근 코드를 따라 들어갈 예정이니 ($1, $2, $3) = (ircot, codex, hotpotqa) 이 1 set에 대해서만 먼저 돌려보겠다. 이러면 runner.py 에 들어갈 argu는 아래와 같다.
echo ">>>> Instantiate experiment configs with different HPs and write them in files. <<<<"
python runner.py ircot codex hotpotqa write --prompt_set 1
echo ">>>> Run experiments for different HPs on the dev set. <<<<"
python runner.py ircot codex hotpotqa predict --prompt_set 1
echo ">>>> Run evaluation for different HPs on the dev set. <<<<"
python runner.py ircot codex hotpotqa evaluate --prompt_set 1
echo ">>>> Show results for experiments with different HPs <<<<"
python runner.py ircot codex hotpotqa summarize --prompt_set 1
- 이 argu가 들어갈 때, 최종 적으로 run.py에는 어떤 식으로 argu가 들어가는지 확인 하기 위해서 아래와 같이 runner.py에 debug 걸고 run.py가 돌기 직전에 break point 잡은 다음 argu를 넣어준다.
그 결과 위에 argu 들은 최종적으로 아래와 같이 run.py에 들어가게 된다.
# python runner.py ircot codex hotpotqa write --prompt_set 1
python run.py write ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --no_diff
# python runner.py ircot codex hotpotqa predict --prompt_set 1
python run.py predict ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl --skip_if_exists --silent
# python runner.py ircot codex hotpotqa evaluate --prompt_set 1
python run.py evaluate ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl
# python runner.py ircot codex hotpotqa summarize --prompt_set 1
python run.py summarize ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl
- 이제 run.py의 main 함수 가장 첫 줄에 break point 걸고 debug mode로 argu 넣고 한줄씩 따라가면 끝.
1. python run.py write ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --no_diff
위의 코드를 돌리면, 그냥 실험 할 때 쓰이는 config를 생성한다. 이게 어떻게 쓰이는진 더 뒤에 보자.
2. python run.py predict ircot_codex_hotpotqa --instantiation_scheme ircot --prompt_set 1 --evaluation_path processed_data/hotpotqa/dev_subsampled.jsonl --skip_if_exists --silent
run.py 속에는 총 7개의 subprocess.call이 있는데. 위에 argu로 돌리면 line 727번에 있는 subprocess.call이 불리고 그때, predict.py가 불린다.
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__4___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__6___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__2.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
python predict.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__8___distractor_count__3.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --silent
우리는 이 중 첫번째 argu로 다시 predict.py를 돌려보자.
predict.py 속에는 또 3개의 subprocess.call이 존재하고, 이 3개가 차례대로 모두 불린다. 불린 command 는 아래와 같다.
# Run predict_command:
RETRIEVER_HOST=http://localhost RETRIEVER_PORT=8000 LLM_SERVER_HOST=http://localhost LLM_SERVER_PORT=8010 python -m commaqa.inference.configurable_inference --config instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet --input processed_data/hotpotqa/dev_subsampled.jsonl --output predictions/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1/prediction__hotpotqa_to_hotpotqa__dev_subsampled.json --silent
# Run evaluate_command:
python evaluate.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl
# Run evaluate_command:
python evaluate.py instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet processed_data/hotpotqa/dev_subsampled.jsonl --official
다시 또 한줄 한줄 살펴봐야지 뭐... 아오 귀찮아... 첫번째 명령어를 살펴보면 환경변수 4개를 아래와 같이 선언 한 다음.
RETRIEVER_HOST=http://localhost
RETRIEVER_PORT=8000
LLM_SERVER_HOST=http://localhost
LLM_SERVER_PORT=8010
commaqa.inference.configurable_inference.py 를 실행시키고, 그때 argu로 아래의 4개를 넘기네.
--config instantiated_configs/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1.jsonnet
--input processed_data/hotpotqa/dev_subsampled.jsonl
--output predictions/ircot_codex_hotpotqa____prompt_set_1___bm25_retrieval_count__2___distractor_count__1/prediction__hotpotqa_to_hotpotqa__dev_subsampled.json
--silent
Debugger로 따라 가려고 아래와 같이 환경변수는 임으로 선언해주고, debugger를 run하니 error가 발생하네. 그 이유는 debugger를 "Python Debugger: Current File with Arguments" 로 실행 시키면, dir을 타고 들어가서 configurable_inference.py 를 실행 시키기 때문이다.
이걸 해결 하기 위해서는 launch.json에 들어가서 아래와 같이 추가하여 debugger를 "Python Debugger: Module"로 실행 시키고 module의 이름을 "commaqa.inference.configurable_inference" 이렇게 설정해서 debugger가 안타고 들어가게 만들면 된다.
# launch.json
{
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Module",
"type": "debugpy",
"request": "launch",
"module": "commaqa.inference.configurable_inference",
"args": "${command:pickArgs}",
},
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": "${command:pickArgs}",
"subProcess": true,
},
]
}
구성을 살펴보면 매우 간단하다.
build_decompser_and_models 에 들어가면
retriever : "retrieve_and_reset_paragraphs" -> class RetrieveAndResetParagraphsParticipant
cot_reasoning_gen: "step_by_step_cot_gen" -> class StepByStepCOTGenParticipant
exit_controller: "step_by_step_exit_controller" -> class StepByStepExitControllerParticipant
그리고 위의 3 class는 모두 ParticipantModel를 mom class로 가지고 있네.
model_map은 그냥 jsonnet으로 부터 parsing한 model의 정보들을 map으로 가지고 있는 것.
ModelController는 그냥 model_map을 가지고 있는 녀석이라 생각하면 될듯.
근데 결국 decompser도 model_map을 가지고 있는데... 쩝 코드 대충 만들었나보지.
이제 load_reader를 보자.
별거 없고 그냥 구현되어 있는게 class MultiParaRCReader 이거 뿐이네..
다음에 오면 이제 MultiParaRCReader.read_examples( ) 함수를 살펴보면 될듯. 여기에 input으로 'processed_data/hotpotqa/dev_subsampled.jsonl' 가 들어가거든.
'AI > LLM Paper review' 카테고리의 다른 글
Paper Review "Retrieval-Augmented Generation for Large Language Models: A Survey" (한국어) (0) | 2024.02.02 |
---|