본문 바로가기
코딩/LLM

LLM2Vec 디코더 전용 LLM을 텍스트 인코더로 변환하는 방법

by 킹형 2024. 8. 8.
728x90
300x250

LLM2Vec: 디코더 전용 LLM을 텍스트 인코더로 변환하는 방법

LLM2Vec은 강력한 텍스트 임베딩을 위해 디코더 전용 대규모 언어 모델(LLM)을 활용하는 새로운 접근 방식을 도입하여 BERT와 같은 기존 인코더 기반 방법과 대조됩니다. 이 백서에서는 텍스트 임베딩 작업에 디코더 전용 모델을 사용하는 방법, 실험 및 그 효과를 입증하는 결과에 대해 설명합니다.

기존의 텍스트 임베딩 방법론:
인코더 모델(예: BERT)은 입력의 모든 토큰을 동시에 고려하여 양방향 주의를 사용합니다.

디코더 모델(예: GPT)은 인과주의를 사용하여 현재 위치까지의 과거 토큰에만 초점을 맞춥니다.

디코더 모델 관련 문제: 자동 회귀 특성은 향후 토큰에 주의를 기울일 수 없기 때문에 차선책 임베딩으로 이어집니다.

LLM2Vec은 대규모 언어 모델(LLM)을 효율적인 텍스트 인코더로 변환하는 새로운 접근 방식입니다.

image

  1. Bidirectional Attention:
    • 전통적인 디코더 모델(GPT 등)은 어텐션를 사용하여 현재 토큰이 이전 토큰에만 주의를 기울이게 합니다.
    • LLM2Vec은 인과 주의를 양방향 주의로 변환하여 모든 토큰이 서로 주의를 기울일 수 있게 합니다.
  2. 마스크된 다음 토큰 예측(MNTP):
    • MNTP는 인코더 모델(BERT 등)에서 사용되는 마스크된 언어 모델링(MLM)과 유사합니다.
    • 입력 토큰의 일부를 무작위로 마스킹하고, 이러한 마스킹된 토큰을 예측하도록 모델을 훈련하여 문맥적 이해를 향상시킵니다.
  3. 비지도 대조 학습(SimCSE):
    • SimCSE에서 영감을 받아, 서로 다른 드롭아웃 마스크를 사용하여 동일한 문장에서 긍정적 쌍과 부정적 쌍을 생성합니다.
    • 모델은 이를 통해 강력한 텍스트 임베딩을 학습합니다.

주요 실험 및 결과

  • 모델 및 훈련:
    • Sheared-LLaMA-1.3B, Llama-2-7B-chat, Mistral-7B-Instruct-v0.2 모델을 사용하여 영어 Wikipedia에서 훈련.
    • Massive Text Embedding Benchmark(MTEB)에서 평가
  • 성능 향상:
    • 양방향 주의 방식으로 성능이 크게 향상되었으며, 주의 깊은 적용과 훈련이 필요함.
    • LLM2Vec 모델은 다양한 벤치마크에서 기존 모델보다 뛰어난 성능을 보여줌.

결론 및 실용적인 적용

  • 결론:
    • LLM2Vec은 디코더 전용 LLM을 강력한 텍스트 인코더로 변환하여, 인과 주의의 한계를 해결하고 다양한 텍스트 임베딩 작업에서 성능을 향상시킵니다.
    • Massive Text Embedding Benchmark(MTEB)에서 평가에서 높은 순위를 얻을 수 있었습니다.
  • 실용적인 적용:
    • NLP 작업: 향상된 임베딩으로 의미론적 텍스트 유사성, 정보 검색, 클러스터링 등 다양한 NLP 작업에서 성능이 향상될 수 있습니다.
    • 모델 유연성: 이 방법을 사용하면 전통적으로 인코더 기반 모델에서 처리했던 작업에 사전 훈련된 대규모 LLM을 사용할 수 있습니다.

자세한 내용과 구현은 arXiv의 전체 논문과 GitHub의 관련 코드를 참조하세요.

https://github.com/McGill-NLP/llm2vec

구현

훈련을 위해서는 llm2vec을 설치할 필요가 있습니다.

먼저 라마3 자체를 llm2vec에 맞게 훈련을 시켜줄 필요가 있습니다. mntp 방식으로 훈련을 먼저 시켜보도록하겠습니다.

lora를 통해서 빠르게 학습이가능합니다. 원하는 커스텀 데이터가 필요하다면 json형식에 text 구조를 지켜주시면 됩니다.

해당 구조에 맞춰서 데이터가 준비 되었으면 config를 수정해줍니다. train_configs/mntp/MetaLlama3.json에 필요한 config가 들어가 있습니다.

훈련을 돌리면 됩니다.

[
  {
    "text": ""
  }
]
nohup python -u experiments/run_mntp.py train_configs/mntp/MetaLlama3.json >nohup.out &
[INFO|trainer.py:641] 2024-06-17 29:02:31,355 >> Using auto half precision backend
[2024-06-17 19:02:31,565] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[INFO|trainer.py:2078] 2024-06-17 19:02:31,761 >> ***** Running training *****
[INFO|trainer.py:2079] 2024-06-17 19:02:31,761 >>   Num examples = 45,554
[INFO|trainer.py:2080] 2024-06-17 19:02:31,761 >>   Num Epochs = 3
[INFO|trainer.py:2081] 2024-06-17 19:02:31,761 >>   Instantaneous batch size per device = 6
[INFO|trainer.py:2084] 2024-06-17 19:02:31,761 >>   Total train batch size (w. parallel, distributed & accumulation) = 24
[INFO|trainer.py:2085] 2024-06-17 19:02:31,761 >>   Gradient Accumulation steps = 4
[INFO|trainer.py:2086] 2024-06-17 19:02:31,761 >>   Total optimization steps = 5,694
[INFO|trainer.py:2087] 2024-06-17 19:02:31,765 >>   Number of trainable parameters = 567,279,616

이렇게 되면 구동 완료

실행

```
import torch
from llm2vec import LLM2Vec

l2v = LLM2Vec.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
device_map="cuda" if torch.cuda.is_available() else "cpu",
torch_dtype=torch.bfloat16,
)

Encoding queries using instructions

instruction = (
"Given a web search query, retrieve relevant passages that answer the query:"
)
queries = [
[instruction, "how much protein should a female eat"],
[instruction, "summit define"],
]
q_reps = l2v.encode(queries)

Encoding documents. Instruction are not required for documents

documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
]
d_reps = l2v.encode(documents)

Compute cosine similarity

q_reps_norm = torch.nn.functional.normalize(q_reps, p=2, dim=1)
d_reps_norm = torch.nn.functional.normalize(d_reps, p=2, dim=1)
cos_sim = torch.mm(q_reps_norm, d_reps_norm.transpose(0, 1))

print(cos_sim)
"""
tensor([[0.6470, 0.1619],

728x90
300x250