본문 바로가기
코딩/LLM

DPO LLM 강화학습법에 대해서

by 킹형 2024. 1. 24.
728x90
300x250

DPO는 RLHF(Reinforcement Learning from Human Feedback)에 의존하지 않고 사용자 선호도 데이터를 직접 사용하여 언어 모델(LM)을 최적화하는 방법입니다.

주요 목표는 고품질 출력을 생성하기 위해 사용자 선호도를 기반으로 언어 모델을 훈련하는 것입니다.
DPO는 강화학습 정책(PPO와 같은 것)을 사용하지 않으면서도(reward 모델링 없이) reward 함수와 기존 정책을 연결하여 인간의 선호 데이터에 최적화할 수 있다고 논문에서 설명합니다.
논문에 따르면 RLHF로 finetuning한 모델과 비교했을 때, DPO는 요약, single-turn 문제에서 더 우수한 성능을 보였습니다.

DPO의 동기:
DPO는 RLHF에 의존하지 않고 코드 생성과 같은 작업을 위한 언어 모델을 개선하는 과제를 해결합니다. 언어 모델 정책을 개선하기 위해 사용자 선호도 데이터를 직접 활용하는 지도 학습 접근 방식을 제안합니다.
DPO는 수십억 개의 토큰으로 구성된 대규모 데이터 세트에서 운영되며 자기 지도 및 자기 주도 학습을 활용하여 다양한 지식 콘텐츠를 획득합니다.

선호 기반 보상 메커니즘:
DPO는 인간의 선호도를 보상 메커니즘으로 활용하여 사용자 선호도에 따라 언어 모델을 최적화합니다. 이 접근 방식은 RLHF에만 의존하지 않고 선호하는 결과를 생성하기 위해 모델 정책을 개선하는 데 중점을 둡니다.

DPO 목표 및 손실 분석:
Cross-entropy loss function used in DPO.
DPO에서 reward model이 필요없음을 시사하고 있음 충분히 언어모델은 이미 reward model 역할을 수행하고 있다
언어 모델이 본질적으로 보상 모델 자체로 작용하기 때문에 두 번째 단계를 건너 뛸 수 있음을 보여줍니다.
RLHF의 두 번째 단계가 제거되면, 문제는 그림와 같이 크로스 엔트로피 목표를 가진 최적화 문제로 크게 단순화 된다.

DPO는 명시적인 보상 모델링 단계를 건너뛰고 최적화 문제를 이진 분류 작업으로 구성하여 음의 로그 라이클리후드 손실을 최소화합니다.
리워드 모을 사용하지 않고 목적을 구성하고, KL 다이버전스가 0에 가까워지면 최종 최적 솔루션이 달성되어 언어 모델이 선호 분포에 맞춰집니다.
이러한 요점은 직접 선호도 최적화와 관련된 주요 개념 및 기술에 대한 개요를 제공하며 사용자 선호도를 기반으로 언어 모델을 훈련하는 대안적인 접근 방식을 제공합니다.

코드
https://huggingface.co/datasets/Intel/orca_dpo_pairs
https://huggingface.co/docs/trl/main/en/dpo_trainer

데이터셋으로는 chosen과 rejected 2가지가 쌍으로 준비되어야된다.
intel의 orca dpo pair의 경우, GPT로 생성한 데이터와 llama2-13b-chat 2가지 생성된 것이 있는데 무조건 llama2 13b가 GPT보다 않좋다고 할 수 없지만 합성 데이터를 통해 대체로 좋은 경향의 GPT의 학습하려고 하는 것
학습 자체는 trl 라이브러리의 DPOTrainer가 구현 되어있어서 데이터셋만 제대로 준비해준다면 학습이 된다. LoRA에도 대응되고 full fine tuning도 대응한다.

def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['question']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chatgpt'] + "<|im_end|>\n"
    # Format rejected answer
    rejected = example['llama2-13b-chat'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

# Load dataset
dataset = load_dataset("Intel/orca_dpo_pairs")['train']

# Save columns
original_columns = dataset.column_names

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Format dataset
dataset = dataset.map(
    chatml_format,
    remove_columns=original_columns
)
# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
)

# Model to fine-tune
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)
model.config.use_cache = False

# Reference model
ref_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)

# Training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    max_steps=200,
    save_strategy="no",
    logging_steps=1,
    output_dir=new_model,
    optim="paged_adamw_32bit",
    warmup_steps=100,
    bf16=True,
    report_to="wandb",
)

# Create DPO trainer
dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
    max_prompt_length=1024,
    max_length=1536,
)

# Fine-tune model with DPO
dpo_trainer.train()

참고자료

직접 선호도 최적화(DPO)

https://pakhapoomsarapat.medium.com/forget-rlhf-because-dpo-is-what-you-actually-need-f10ce82c9b95

https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb

728x90
300x250