본문 바로가기
논문 리뷰/Reinforcement Learning

[R2D2] Recurrent Experience Replay in Distributed Reinforcement Learning

by 박사개구리 2024. 9. 10.

Paper Link: https://openreview.net/pdf?id=r1lyTjAqYX

1. Abstract

  • 최근에 강화학습에서의 분산 학습의 성공에 따라, 분산된 PER(Prioritized experience replay)로 부터 RNN기반의 RL agents를 학습 시키기 위한 방법을 탐구했다.
  • We study the effects of parameter lag resulting in representational drift and recurrent state staleness and empirically derive an improved training strategy.
  • representational drift (표현 이동)과 recurrent state staleness (recurrent state의 부패[update된 network와 과거의 network에서 얻어진 recurrent state간 차이가 심각함.])의 결과에서 parameter lag (actor들의 증가에 따라 업데이트의 딜레이가 미약하게라도 발생할 수밖에 없음.)의 영향에 대해 연구했고, 실험적으로 향상된 학습전략을 유도했다.

2. Introduction

  • 도적적인 문제(Atari에서의 human-level, Alpha go, Dota etc.)들을 계속해서 풀어가며, 강화학습 연구의 관심은 다시 높아졌다.
  • 초기 연구에서 experience replay를 통해 데이터 효율을 높이고, 연속적인 여러 frames을 쌓아 partial observability를 극복했다. 하지만, POMDP의 어려움을 더 극복하기 위해 RNN을 해결책으로 사용한다.
  • 이 논문에서는 RNN과 experience replay를 같이 사용할 수 있는 학습법에 대해 탐구했다.
  • 3가지 주요 기여가 있습니다.
    • 첫번째로, parameter lag, representational drift, recurrent state staleness에 대한 experience replay의 영향에 대해 증명했다. 이 부분들은 분산 학습 상황과 궁극적으로 줄어든 학습 안정성과 성능에 따라 잠재적으로 악화 되어진다.
    • 두번째, 위에서 나타난 문제들이 완화되어지는 효과와 experience replay를 적용한 RNN학습의 영향에 대해 실험적인 연구를 했다.
    • 세번째로, 상당한 발전을 위해 SOTA 모델인 Recurrent Replay Distributed DQN을 제안했다.

3. BackGround

3.1 Reinforcement Learning

  • Partially Observable Markov Decision Process (POMDP)
    • tuple(S, A, T, R, Ω, Ο)Ω: agent가 받을 수 있는 observations.
    • Ο: observation function.
    • S: states, A: actions, T: transition function, R: reward function
    • 여기서 agent가 받을 수 있는 observation o는 Ω에 포함된다. 또한 state에 대한 부분 정보들이다.
    • POMDP의 true state를 확실히 표현 할 수 있도록 학습하기 위해, RNN을 사용하기로 했다.
    • 대표적으로 A3C에 LSTM을 사용한 방법 (하지만, reply buffer는 사용하지 않음)과 zero start state를 사용한 DRQN(replay buffer 사용)이 있다.

3.2 Distributed Reinforcement Learning

  • Ape-X
  • PER, n-step TD, dueling DQN, Double DQN을 learner와 actor로 분리해 Distributed setting을 통해 학습한 알고리즘이다.
  • IMPALA
  • on-policy n-step에서 사용되는 V-trace를 사용한다. queue를 사용하며, 독립된 actor들을 통해 initial recurrent state를 따라 생성된 sequence를 experience queue에 저장한다. actor-learner 관계로 queue에 actor들이 쌓아놓은 experience를 순차적으로 처리한다. 모든 데이터를 다 사용함.

3.3 The Recurrent Replay Distributed DQN Agent

  • n-step = 5, actor = 256, batch = single learner에 따른 replayed experience를 나눈다.
  • 일반적인 transition tuple을 (s, a, r, s'), fixed-length (m = 80) sequences of (s,a,r)로 저장하며, 가까운 sequence들이 겹쳐 사용 된다. 40 time steps. (겹쳐서 반절.) 그리고 에피소드의 경계를 가로지르지 않는다. (게임 종료인 end를 넘어서, 다른 episode의 transition으로 연결해 사용하지 않음.)
  • reward에 대해선 특별히 clip을 적용하지 않았지만, value function을 rescaling 하는 방식을 사용했습니다. (아래 수식을 적용했습니다.)

  • p는 priority value 이며, max and mean absolute n-step TD-erros δ 이며, δ'은 전체mean, η는 priority exponent로 0.9, γ = 0.997
  • store([s,a,r,s'],[s,a,r,s'],[s,a,r,s'])
  • store([s,a,r....40,p],[s,a,r....40],[s,a,r....40])

4. Training Recurrent RL Agents with Experience Replay

  • Partially observed environment에서도 좋은 성능을 달성하기 위해, RL agent는 state-action trajectory와 현재의 observation 사이의 정보인 state representation을 해석해야한다. 이를 위해 가장 좋은 방법은 RNN을 사용 하는 것이다.
  • 기존에는 RNN과 experience replay를 같이 사용하기 위해 2가지 전략을 취했다.
    1. Using a zero start state to initialize the network at the beginning of sampled sequences.
    • Experience Replay에서 샘플을 뽑고, zero start state(RNN hidden을 0으로 시작한다는 의미.)로 학습하는 방식이다.
    1. Replaying whole episode trajectories.
    • 전체 episode trajectories에 대해 학습하는 방식.
  • 2가지 전략의 단점.
    1. 첫번째 방식의 단점은 RNN이 학습되면서 초기 RNN state와 맞지 않고 멀어지는 일이 발생한다. 즉 맨 첫 episode의 시작점을 기준으로 zero start state를 하는 것이 아닌, sampling된 어느 시점을 기준으로 시작해, 학습이 되면서 점점 ('initial recurrent state mismatch')가 발생한다.
    2. 두번째 방식의 단점은 전체 episode trajectories를 학습하다보니, 실용, 계산, 알고리즘 문제를 발생시킨다. environment-dependent 하다보니, update의 variance가 첫번째 방식보다 높아지고, random sampling보다 nature of states에 영향을 많이 받는다.

  • 위 2가지 단점을 해결하기 위해, RNN과 random sampling이 가능한 Experience replay를 학습할 수 있는 2가지 전략을 소개 및 평가 했다.
    1. Stored state
      • RNN의 hidden state를 Experience replay에 같이 저장하는 방식입니다. 이후 학습에도 저장된 hidden state를 initial state로 사용해 zero start state 전략의 단점을 보완합니다. 하지만, 오랜된 experience가 sampling 될 경우 현재 업데이트 되고 있는 network와 차이가 심해 'representational drift'(네트워크가 표현하는 내용이 바뀌는 상황)과 'recurrent state staleness(hidden state 자체가 학습에 도움이 안되는 상황)'이 발생하게 됩니다.
      • 하지만 burn-in 전략에 비해 state staleness를 더 완화했습니다. (Q-value discrepancy 실험을 통해 알게된 점.) 성능 향상에 있어서도 일관적으로 개선되었습니다.
    2. Burn-in
      • 전체 trajectory 중 일정 부분을 'burn-in period(네트워크를 예열한다고 생각)' 단계의 과정으로 사용합니다. 예를 들어 총 80의 sequence라면 앞쪽의 30개 sequence로 RNN에 입력하고 30step 동안 변화된 RNN의 hidden state를 initial state로 사용해 남은 50개 sequence에 대해 학습하는 전략입니다. 마찬가지로 zero start state의 단점을 보완하지만 여전히 'recurrent state staleness'문제는 해결하지 못했습니다.
      • 실험적으로 확인한 결과 burn-in 방식은 'destructive updates'를 방지합니다. zero state initialization의 초기 ouput이 매우 부정확한 상태인걸 보완합니다.
    3. both stored state and burn-in strategy.
      • 2가지 방식을 서로 조합했을 때, 상당한 이점이 있었다.

5. Effect of Distributed RL Agent Training

  • Single learner를 위한 replay buffer에 experience 데이터를 채우는 actor의 수가 매우 많을 때 RNN을 사용하는 에이전트의 분산 학습의 영향을 알아보자
  • Distributed setting은 single actor를 사용하는 경우에 비해 representational drift의 문제가 덜함 (Hausknecht & Stone, 2015)
    • 이 이유는 많은 양의 경험들은 replay 되는 빈도가 적기 때문 (일반적으로 Ape-X에서는 각 경험이 한번 혹은 그 이하로 replay 되었으나 DQN의 경우 8번 정도 replay 됨)
    • 그러므로 분산 에이전트 학습은 'parameter lag'의 상승을 더 적게 유발함 (경험이 replay 되는 시점에서 해당 경험을 생성할 때 사용된 네트워크 파라미터가 얼마나 오래된 것인가?)
  • 분산 세팅은 하드웨어나 시간에 따라 연산 자원을 scaling하기 쉬움
    • 이상적인 분산 에이전트는 다음과 같은 사항들의 변화에 강인해야함 → Actor의 수, 파라미터 재튜닝
    • 이전 section에서 Replay를 사용한 RNN 학습은 representational drift 문제에 민감하다는 것을 확인 → 위와 같은 파라미터들에 따라 문제가 발생할 수 있음
  • 이런 효과들을 살펴보았을 때, 본 논문에서는 더 적은 수의 actor를 사용하는 구조를 제안
    • 이는 parameter lag과 연관이 있음
    • Actor의 수는 256에서 64로 변경 → parameter lag는 1500에서 5500개의 파라미터 업데이트로 변경됨 → representation drift와 recurrent state staleness의 정도에 영향을 미침
  • Figure 1의 왼쪽 칼럼을 보면 적은 수의 actor를 사용했을 때 replay sequence의 첫번째나 마지막 state 둘다 평균 $\bigtriangleup Q$가 증가
    • 위 결과를 통해 분산 세팅에서 학습 전략을 어떻게 설정하느냐가 중요하다는 것을 알 수 있음

6. Experimental Evaluation

  • R2D2의 성능을 Atari-57과 DMLab-30에서 비교
  • Atari: 네트워크의 구조와 hyper-parameter의 설정을 57개의 Atari 게임에 동일하게 설정 (DQN)
  • DMLab: 동일한 네트워크 구조와 hyper-parameter 설정 사용 → 좋은 robustness, generality를 보여줌
  • Atari와 DMLab에서 state-of-the-art 성능 보임

6.1 Atari-57

  • Figure 2의 왼쪽 그래프는 타 기법들과 R2D2의 성능을 모든 Atari 게임에서 median human-normalized score로 비교한 것
  • R2D2는 모든 single actor 알고리즘에 비해 좋은 성능을 보이며 SOTA 성능을 보이는 Ape-X와 비교했을 때도 거의 3배의 성능을 보임
    • Ape-X와 비교했을 때 더 적은 actor 사용 (360 → 256)
    • 높은 sample efficiency
    • 높은 time efficiency
  • Table1을 통해서도 R2D2가 Human-normalized score 측면에서 다른 알고리즘들에 비해 얼마나 좋은 성능을 보이는지 알 수 있음
  • 한가지 파라미터와 네트워크 구조로도 각 게임별로 보았을 때도 매우 좋은 성능을 보임 → Figure 2의 오른쪽에서 MS.PACMAN의 경우 기존의 최대 점수인 Van Seijen et al. (2017)보다 훨씬 좋은 성능을 보임 → 기존의 기법은 MS.PACMAN을 위해 특별히 engineered된 알고리즘
  • Ape-X의 경우 Rainbow와 동일하게 49개의 게임에서 super-human 성능을 보임
  • R2D2의 경우 57개의 게임 중 52개의 게임에서 super-human 성능을 보임
    • 넘지 못한 게임들은 모두 어려운 exploration 조건을 가진 환경들 (Montezuma's Revenge, Pitfall, Skiing, Solaris, Private eye)
    • Exploration 문제를 해결하는 알고리즘을 R2D2에 붙여주면 조만간 모든 57개의 Atari game에서 super-human 성능을 보이는 알고리즘을 만들 수 있을듯!

6.2 DMLab-30

  • DMLab은 3D 일인칭 게임 엔진으로 30개의 문제로 구성되어 있음
  • Atari의 경우 frame stacking만으로 대부분 해결이 가능했지만 DMLab-30 의 경우 적절한 성능을 얻기 위해서는 long-term memory가 요구됨
  • Experience replay에 long-term memory를 결합하는 것이 어려웠던 관계로 그동안 해당 환경에서 좋은 성능을 보인 에이전트들은 대부분 on-policy setting의 actor-critic 기반 알고리즘들이었음
  • R2D2가 처음으로 value-function 기반 알고리즘 중에서 SOTA 성능을 보인 알고리즘
  • Figure 3의 왼쪽을 보면 R2D2와 IMPALA의 성능을 비교함
    • R2D2가 IMPALA보다 작은 네트워크를 사용하고 모든 domain에서 동일한 hyper-parameter를 사용하면서도 훨씬 좋은 성능을 보이는 것을 알 수 있음
    • Table 1에서는 population-based training (PBT) 버전의 IMPALA와 R2D2를 비교한 결과를 볼 수 있으며 R2D2가 훨씬 좋은 최종 median performance를 보이는 것을 알 수 있음

7. Analysis of Agent Performance

  • Atari-57의 환경들은 대부분 fully observable (4 frame stack observation)이며 에이전트가 memory-augmented representation으로부터 이득을 얻을 것으로 기대되지 않았음
  • R2D2와 기존의 알고리즘인 Ape-X의 가장 큰 차이점은 RNN의 사용 여부 → R2D2가 Atari에서 이렇게나 큰 차이로 SOTA 성능을 보인 것이 놀라워요!
  • 이번 section에서는 LSTM 네트워크의 역할을 분석하며 높은 성능을 가지는 R2D2 에이전트의 학습 전략을 알아볼 것

  • 우선 LSTM이 R2D2의 성공에 얼마다 중요한 역할을 하는지 알아보기 위해 R2D2에 순수 feed-forward를 사용한 버전과 성능 비교를 진행함
    • Figure 4를 통해 보았을 때 LSTM을 사용하는 것이 학습 속도와 최종 성능 면에서 도움을 주는 것을 확인할 수 있으며 이것이 Ape-X와의 성능 차이를 설명해줌

  • 다음으로 R2D2의 성능이 memory에 얼마나 의존하는지 알아보자
    • 이를 위해 Atari에서는 MS.PACMAN을 사용 → 게임이 시각적으로 fully observable하지만 R2D2는 state-of-the-art 성능을 보임
    • DMLab의 EMSTM_WATERMAZE는 메모리의 사용을 매우 요구하는 환경
    • 각 게임에서 두개의 전략을 사용하여 학습 (zero state와 stored state)
    • 그리고 history length를 고정하여 에이전트의 성능 확인
    • Figure 5의 왼쪽을 보면 history의 길이를 무한대에서 0으로 감소시키면서 성능을 확인 → 작아질수록 성능이 감소하는 것을 확인할 수 있음
    • max-Q-values의 차이와 percentage of correct greedy actions의 결과도 보여줌
      • 첫번째로 memory를 제한하는 것이 성능을 점진적으로 감소시키는 것을 확인 → 메모리의 사용이 중요함!
        • 특히 stored state의 경우 full history를 사용하는 것이 큰 성능 향상을 보여줌
        • zero start states를 사용하는 경우 성능의 감소가 더 빨랐음
        • 이것은 zero start state가 메모리를 사용하여 학습하는 에이전트의 능력을 제한하는 것을 알 수 있음
        • 메모리의 사용을 많이 요구하는 EMSTM_WATERMAZE와 같은 환경에서 MS.PACMAN 같은 환경보다 메모리의 사용이 더 큰 영향을 미치는 것을 알 수 있음
        • 이를 통해 stored state가 zero state strategy보다 좋은 성능을 보임을 알 수 있으며 R2D2가 비슷한 알고리즘인 Reactor (Gruslys et al., 2018)보다 좋은 성능을 보이는 것이 설명됨
      • 마지막으로 Figure 5의 오른쪽과 중간 column을 통해 history length가 0으로 작아짐에 따라 Q-value와 greedy policy의 quality가 단조 감소하는 것을 확인할 수 있음
        • 메모리에 대한 제한이 에이전트 성능에 어떤 영향을 미치는지 확인해 볼 수 있음

8. Conclusions

  • 우선 본 논문의 결과로 두가지 놀라운 점을 찾아냄
  • 첫번째 놀라운 점
    • 기존의 연구들 (Hausknecht & Stone, 2015; Gruslys et al., 2018)에서 많이 사용되었던 zero state initialization이 action-values를 잘못 예측하는 것을 유발할 수 있음 (특히 replayed sequence의 초반 state들에 대해서)
    • 또한 burn-in 없이 BPPT를 통해 update 하는 것은 초기 time step들에 대해 좋지 않은 추정 결과를 도출할 수 있으며 나쁜 업데이트를 일으킬 수 있음
    • suboptimal한 초기 recurrent state에서 network의 능력을 회복하는 것을 방해
    • context dependent recurrent state를 replay를 replay에 저장하거나 burn-in을 위한 replayed sequences의 initial part를 사용하는 것이 좋음 → RNN이 recurrent state나 long-term temporal dependencies의 exploit에 의존하므로 두가지 전략을 적절하게 결합하는 것도 가능
    • 또한 분산 세팅에서 representational drift나 recurrent state staleness가 성능을 악화시킬 수 있다는 것을 확인 → RNN의 적절한 학습 전략을 사용해서 robustness를 확보하는 것이 중요!
  • 두번째 놀라운 점
    • 메모리를 사용하는 에이전트를 통해 LSTM으로 학습하는 것의 효과를 확인
    • LSTM 학습은 기존의 RL 연구들에서는 중요하게 많이 사용되지 않았지만 더 나은 representation learning을 제공함으로써 fully observable하여 메모리를 크게 요구하지 않는 환경에서도 큰 성능 향상을 보이는 것을 확인
  • 결론!
    • Atari-57과 DMLab-30에서 병렬화, 분산 환경을 사용하여 취득한 많은 경험 데이터를 통해 RL agents를 scaling up하였고 이를 통해 성능의 향상을 가져옴
    • 높은 sample complexity 때문에 기존에 학습을 위해 몇십억 스텝을 몇 일 동안 학습한 것에 비해 몇시간 만에 학습을 수행했다는 점이 인상깊음
    • Future work
      • sample efficiency를 통해 기존에 빠른 simulation이 어려웠던 도메인에 적용해볼 것
      • Exploration 기법을 추가적으로 도입하여 5개의 어려운 탐험 환경까지 super-human performance에 도달할 수 있도록 할 것

Appendix

Hyper-Parameters

  • R2D2는 DQN과 동일한 3 layer CNN을 사용하였고 CNN 이후에 512 hidden unit을 가지는 LSTM 사용
  • LSTM 연산 결과를 dueling network 구조인 advantage와 value head로 나누어줌 (각각 hidden layer size = 512)

Full Result (DMLab-30)

R2D2 Learning Curves on 57 Atari games

Performance of R2D2 (대조군: Reactor, IMPALA, Ape-X)