DeepLearing/NLP(Reasoning)

[논문리뷰](25.03)KIMI K1.5:SCALING REINFORCEMENT LEARNING WITH LLMS

notdecidedyet 2025. 3. 29. 12:46

2.1 RL Prompt Set Curation: 고품질 데이터를 정의함.

2.2 이 프롬프트에는 planning, evaluation, reflection, exploration과 같은 인지 프로세스(프롬프트)를 사용함.

2.3.4 (Ankner et al. 2024; McAleese et al. 2024)는 chain-of-thought(CoT) 추론으로 증강된 보상 모델이 특히 미묘한 정확성 기준이 중요한 과제...

2.5.2 Vanilla Supervised Finetuning: 수학과 코딩 같은 추론 작업에서는 규칙 기반 및 보상 모델링 기반 검증이 인간 판단보다 정확하고 효율적이므로, rejection sampling을 사용해 SFT 데이터셋을 확장

Abstract

 

  • 언어 모델 사전 훈련의 한계:
    • Next token prediction은 쉽고 효과적이지만 사용 가능한 훈련 데이터의 양이 제한적
    • 이전 발표된 강화학습은 좋은 결과를 내지 못함
  • Kimi k1.5 소개:
    • 최신 멀티모달 LLM으로 RL로 훈련됨
    • RL 훈련 기법, 멀티모달 데이터 레시피, 인프라 최적화 포함
  • 핵심 접근법:
    • 롱 컨텍스트 확장
    • 개선된 policy 최적화 방법
    • 간단하고 효과적인 RL 프레임워크 구축(w/o MCTS, value functions, PRM)

 

Introduction

  • Scaling law에 따라 모델 파라미터와 데이터 크기를 비례적으로 확장하나, 사용 가능한 고품질 훈련 데이터의 양이 제한됨.
  • Kimi k1.5의 목표:
    • 강화학습을 통한 지속적 확장, 새로운 방향 탐색
    • 모델이 보상을 통해 학습하고 탐색하여 정적 데이터셋 제한 극복
  • 핵심 설계 요소:
    1. 롱 컨텍스트 확장:
      • RL의 컨텍스트 윈도우를 128k로 확장
      • 부분 롤아웃을 사용한 훈련 효율성 개선
      • 이전 궤적의 큰 부분을 재사용하여 새 궤적 샘플링
      • 컨텍스트 길이가 LLM과 RL지속적 확장을 위한 핵심 차원임을 확인
    2. 개선된 policy 최적화:
      • Long CoT와 함께 RL 공식 도출
      • 강력한 정책 최적화를 위한 온라인 mirror descent 변형 사용
      • 효과적인 샘플링 전략, 길이 패널티, 데이터 레시피 최적화
    3. 단순화된 프레임워크:
      • 롱 컨텍스트 확장개선된 policy 최적화 방법 결합
      • 학습된 CoT는 계획, 반성, 수정 특성 구현
      • 컨텍스트 길이 증가로 검색 단계 수 증가 효과 (MCTS와 같은 복잡한 것 사용 안해도 됌)
    4. 멀티모달리티: 텍스트와 비전 데이터에 공동 훈련
  • Long2short 방법:
    • 롱 CoT 기법으로 short CoT 모델 개선
    • 롱 CoT 활성화로 길이 패널티 적용
    • 모델 병합 활용

 

2. Approach: Reinforcement Learning with LLMs

Kimi k1.5의 개발 과정: Pretraining - SFT - long-CoT supervised fine-turning - RL

 

2.1 RL Prompt Set Curation

사전 실험을 통해, RL prompt set의 품질과 다양성이 RL의 효과를 좌우하는것을 발견함. 잘 구성된 prompt set데이터는 모델이 강력한 추론을 하도록 하는 것 뿐아니라, reward hacking에 과적합 되는 위험을 완화한다. 고품질 데이터를 아래와 같이 정의함.

  1. 다양한 도메인(Diverse Coverage): STEM, 코딩, 일반 추론과 같은 다양한 분야 포함
    -> 모델의 적응성 향상 및 다양한 영역에서 적용 가능성 보장
  2. 균형 잡힌 난이도(Balanced Difficulty): 쉬운, 중간, 어려운 문제가 잘 분포된 범위 포함
    -> 점진적 학습 촉진 및 특정 복잡성 수준에 과적합 방지
  3. 정확한 평가 가능성(Accurate Evaluability): 객관적이고 신뢰할 수 있는 평가 허용
    -> 모델 성능이 올바른 추론에 기반하여 측정되는지 확인
  • 다양한 도메인을 위한 접근법:
    • 풍부한 추론이 필요하고, 평가하기 쉬운 데이터를 자동화 방법으로 선택함.
      -> STEM 분야, 경쟁, 일반 추론 작업 등 다양한 영역 문제 포함
    • 텍스트만 있는 데이터와, 이미지-텍스트 질문-답변, 데이터 통합
    • 도메인과 분야별 태깅 시스템 개발 -> 다양한 분야의 데이터를 포함하려 함.
  • 난이도 평가를 위한 모델 기반 접근 방식: 각 prompt에 대해 SFT 모델이 높은 sampling temperature로 10번 답변 생성
    -> 통과율 계산하여 난이도의 프록시로 사용(통과율이 낮을수록 난이도 높음)
  • Reward hacking 방지: 찍어서 맞출 수 있는 문제들 제거
    • 다지선다, 참/거짓, 증명 기반 질문 제거
    • CoT를 거치지 않고 곧바로 답변을 추측해보라고 함(8번 샘플링)
      -> 이때 정답을 맞추면 쉬운 문제라고 간주하여 제거함.

2.2 Long-CoT Supervised Fine-Tuning

2.1을 통해 정제된 데이터를 바탕으로 prompt 엔지니어링을 사용해, 적은양이지만 long-CoT데이터셋을 구성함(고품질, warm up). 

이 프롬프트에는 planning, evaluation, reflection, exploration과 같은 인지 프로세스(프롬프트)를 사용함.

 

2.3 Reinforcement Learning

2.3.1 Problem Setting

  • 기본적으로 LLM은 복잡한 문제를 잘 풀지 못함. 이를 극복하기 위해
    • CoT방법을 활용
    • 트리 확장을 활용
    • 계획 알고리즘을 통해 exploration을 더 함.
      -> Test Time Compute중 하나임. 근데 잘 생각해보면, Token을 많이 생성하도록하면 exploration을 할 수 있을 것 같다는 결론에 다달함. 따라서 RL을 사용해 CoT를 생성하도록 훈련하는 것을 고려함(OpenAI 2024).
      • 보상 모델r이 주어진 문제 x, 정답y에 대해 생성된y의 정확성을 평가한다고 가정하면 보상은 {0,1}임. 
      • objective to optize the policy는 아래와 같음.
        $max E_{(x, y*)~D, (y,z)~pi_{\theta}} [r(x,y,y*]$
      • RL훈련을 확장해서, prompt-based CoT와, planning-agumented CoT모두 활용하는 모델을 훈련하는 것을 목표로 한다. 모델은 이 과정에서, 오류 식별, 되돌아가기, 솔루션 개선과 같은 중요한 계획 기술을 배운다.

2.3.2 Policy Optimization

온라인 정책 mirror descent의 변형을 적용함.

  • $π_{\theta}$: 현재 정책 모델(현재 훈련중인 LLM)
  • x: query
  • $(y_j, z_j)$: j번째 샘플링된 응답, y는 최종 답변, z는 중간 추론
  • y*: x에 대한 정답
  • $r(x,y_j,y*)$: j번째 샘플링된 보상 -> 0이면 오답, 1이면 정답
  • $r̄$ = $mean(r(x, y_1, y*), ..., r(x, y_k, y*))$: 모든 샘플에 대한 보상의 평균. 베이스라인 역할을 한다.
  • $π_{\theta_i}$: 바로 전 에폭에서 훈련된 모델.

뒤에 있는 텀은 모델이 한번에 너무 많이 변하는 것을 제지하는 역할을 함.

결론: DeepSeek R1의 GRPO와 비슷한 느낌의 손실함수이다. 

 

  • 이 논문에서는 Value Network를 제거함.
    • Value는 특정 상태나 행동이 얼마나 좋은지를 평가함. 이 모델은 중간 과정이 얼마나 최종 답변에 기여했는지 평가함.
    • 제외한 이유: 실제 inference상황에서는 시행착오가 있을 수 밖에 없음(LLM이 전지하지 않기에 한번에 문제를 풀 수 없다고 가정함). 근데 value network로 학습을 하게되면 완벽한 경로를 선택하도록 학습한다
      • 이럴 경우 아래를 배울 수 없다고 생각함
        • 오류를 인식하고 수정
        • 잘못된 경로에서 되돌아가는 능력
        • 다양한 접근법을 시도하는 능력
      • 따라서 이 연구는 최종 답변이 맞는지만 보상으로 사용함.

2.3.3 Length Penalty

  • overthinking하는 현상을 발견함. 이는 계산비용, 사용자 경험, 효율성과 관련 있음.
    • 이를 완화하기 위해서 아래와 같은 식을 도입함.

 

  1. 문제 x에 대해 모델이 k개의 서로 다른 응답 (y1, z1), ..., (yk, zk)을 생성
  2. 각 응답의 길이를 len(i)로 측정
  3. 가장 짧은 응답의 길이를 min_len = min_i len(i)로 정의
  4. 가장 긴 응답의 길이를 max_len = max_i len(i)로 정의
  5. 정답을 맞췄을 때 받는 리워드 값 λ을 minmax scale과 같이 변형해서 주어짐.

다만 이 식을 초기부터 적용하면 학습에 부정적이라 점진적으로 시행함.

 

 

2.3.4 샘플링 전략

  1. 다양한 난이도의 문제로 훈련함: 수학 경시문제, 초등 수학문제 등.
  2. RL중 여러번 샘플링 하기 때문에, 각 문제에 대한 성공률을 난이도 측정 도구로사용할 수 있음.
  3. 훈련 효율성을 개선하기 위해, 커리큘럼 샘플링을 진행함.
  4. 우선순위 샘플링도 적용함. 잘 맞추지 못하는 문제에 집중하여 풀도록 하기위해, 잘 못푸는 문제를 샘플링할 확률을 높여서 훈련함.

 

2.3.5 More Details on Training Recipe

  • 코딩 문제: 테스트 케이스를 자동으로 생성하여 테스트함.
  • 수학 문제:
    • Chain-of-Thought RM사용: 훈련함.
      • 입력: 질문, 정답, 모델의 응답
      • 출력: 단계별 추론 과정, 최종 판단(점수)
      • 작동 방식: 점수를 출력하기 전에, 왜 그러한지 이유를 작성함.
      • 데이터 수: 800k 수집함. -> 구체적으로 어떻게 수집을 했는지 안나옴.

 

2.4 Long2Short: Context Compression for Short-CoT Models

좀더 짧게 생성하게 유도하기 위해 다음과 같은 방법을 사용함.

  • 모델 병합: non-reaosning model, reasoning모델의 웨이트들을 가중평균함.
  • Shortest Rejection Sampling: 같은 문제에 대해서 8번 샘플링하고, 가장 짧은 답변으로 SFT
  • DPO: 더 짧은 답변에 대해서 선호도를 주어 훈련함.
  • Long2Short RL: 2.3.3에서 소개한 방법을 진행함.

 

2.5 Other Training Details

2.5.1 Pretraining: 영어, 중국어, 코드, 수학 추론, 지식의 다섯 가지 도메인을 포함. 캡셔닝, 이미지-텍스트 Interleaving, OCR, 지식, QA 데이터셋을 포함한 멀티모달 데이터는 우리 모델이 비전-언어 능력을 습득할 수 있게 함. (Appendix B)

 

2.5.2 Vanilla Supervised Finetuning: 비추론 작업(질문-응답, 글쓰기, 텍스트 처리 등)의 경우, 인간 주석을 통해 초기 시드 데이터셋을 만들어 시드 모델을 훈련. 수학과 코딩 같은 추론 작업에서는 규칙 기반 및 보상 모델링 기반 검증이 인간 판단보다 정확하고 효율적이므로, rejection sampling을 사용해 SFT 데이터셋을 확장

 

 

2.6 RL Infrastructure

2.6.1 Large Scale Reinforcement Learning Training System for LLM

2.6.2 Partial Rollouts for Long CoT RL

논문과 다르게 시간순으로 재구성해서 작성하자면,,

  1. 초기 설정
    • Master가 시스템을 초기화, 훈련을 준비함. Rollout, Trainer에 모델을 로드함. 둘은 같은 모델이나, 시간의 차이가 있음.
      Rollout은 $π_{θ_i}$이고, Trainer는 $π_{θ_{i+1}}$이다. Reward Model도 로드함(파인튜닝된 모델로 정답이 맞는지 틀린지 판단하는 모델). Replay buffer는 rollout데이터를 저장하는 공간임.
  2. 반복적 훈련
    • Rollout단계: Master가 Rollout worker에게 작업을 분배해 다양한 응답을 생성하도록함. 생성된 궤적은 Replay Buffer에 저장됨. 긴 응답의 경우 Partial Rollout을 사용해 여러 차례에 생성함. (partial rollout은 engineering적인것으로 여기서는 다루지 않음.)
    • 훈련: Trainer는 Buffer에서 데이터를 가져옴. 이 데이터를 바탕으로 Policy를 업데이트함. 업데이트된 가중치는 다음 rollout에서 사용될 수 있도록 Rollout worker에 웨이트를 넘겨줌.
  3. 평가 및 피드백
    • 생성된 답은 Reward모델에 의해 평가됨. -> 이를 바탕으로 웨이트가 업데이트됨
  4. 반복 및 종료
    • 위 과정들이 여러번에 걸쳐 계속되고, 모델의 성능이 점진적으로 향상됨.

2.6.3 Hybrid Deployment of Training and Inference

  • 훈련시: Megatron 프레임워크 사용
  • 추론시: vLLM 사용
  • 과정: 훈련 완료 후, Megatron이 GPU offload -> vLLM에 가중치 전달. 추론 완료후 vLLM을 종료하고 다시 Megatron에 GPU onload

 

3. Experiments

3.2 Main Results

 

3.3 Long Context Scaling

RL과 LLM의 스케일링 속성을 연구하기 위해 중간 크기 모델을 사용해서 실험함. 위 그림은 훈련이 진행됨에 따라, 정확도 및 응답길이가 어떻게 변하는지 보여줌.

특히, 어려운 벤치마크는 응답 길이가 더 가파른 증가세를 보임. -> 이는 모델이 복잡한 문제에서 더 정교한 답변을 생성하는 방법을 학습하는 것을 의미함. 

 

 

그림 6은 모델의 컨텍스트 길이와 문제 해결 능력 사이에 강한 상관관계가 있는것을 나타냄. k1.5에서는 128k길이에서 어려운 추론 벤치마크에서 지속적인 개선이 있음을 보여줌.

 

3.4 Long2short

short에서도 long과 비슷한 성능을 보임.

 

3.5 Ablation Studies

모델 크기와 컨텍스트 길이 스케일링: 동일한 데이터셋으로 서로 다른 크기의 모델을 훈련한 결과, 더 큰 모델이 처음에는 더 나은 성능을 보이지만, 더 작은 모델도 RL을 통해 최적화된 긴 CoT를 활용하면 비슷한 성능에 도달할 수 있었다. 다만 더 큰 모델이 일반적으로 더 나은 토큰 효율성을 보여줌. 이는 최상의 성능을 목표로 한다면 더 큰 모델의 컨텍스트 길이를 스케일링하는 것이 유리하지만, 계산 리소스에 제약이 있는 경우 긴 컨텍스트 길이를 가진 작은 모델도 실용적인 대안이 될 수 있음을 시사한다.

-> DeepSeekR1과는 다른 결과네?

 

부정적 그래디언트 사용의 효과: 정책 최적화 알고리즘으로 ReST와 이 연구의 방법을 비교했을 때, 이 연구의 방법이 더 나은 샘플 효율성을 보여줌. ReST는 현재 모델에서 샘플링된 최상(긍정적 답변)의 응답에만 맞추는 반면, 이 연구의 방법은 부정적 답변도 그래디언트에 반영되어 명시적으로 패널티를 준다. 이는 긴 CoT 생성 능력을 효과적으로 향상시키는 데 중요했으며, 더 적은 훈련 샘플로도 강력한 성능을 달성할 수 있다.

 

샘플링 전략: 커리큘럼 샘플링 전략의 효과를 검증. 초기에 일반적인 데이터셋으로 웜업 단계를 진행한 후 어려운 문제에 집중하는 접근법이 균일한 샘플링 전략보다 우수한 성능을 보임.