반응형
- 논문 링크: https://arxiv.org/pdf/2212.09748
- 프로젝트 페이지: https://www.wpeebles.com/DiT
- 깃허브: https://github.com/facebookresearch/DiT

0. Abstract
- 본 논문에서는 트랜스포머 (Transformer) 구조를 백본 (Backbone)으로 하는 새로운 확산 모델 (Diffusion Model) 제안
- 이미지에 대한 잠재 확산 모델 (Latent Diffusion Model, LDM)을 학습 → 일반적으로 사용되는 U-Net 모델 대신 잠재 패치에서 작동하는 트랜스포머 사용
- 또한 본 논문에서 제안하는 DiT (Diffusion Transformers)의 확장성에 대해 분석
- 입력 토큰 수의 증가에 따라 트랜스포머의 깊이/너비가 증가할수록 낮은 FID를 가짐
- 가장 큰 모델인 DiT-XL/2의 경우 기존의 모든 확산 모델의 성능을 뛰어넘음 → ImageNet 512x512와 ImageNet 256x256 벤치마크에서 SoTA FID인 2.27 달성
1. Introduction
- 최근 머신러닝에서는 자연어처리, 비전 외 다양한 분야들에 트랜스포머 모델을 주로 사용
- 하지만 이미지 생성에서는 트랜스포머보다는 합성곱 (Convolution) 기반의 U-Net 구조의 기반 모델을 주로 사용하는 경향이 있음
- 하지만 본 논문에서는 U-Net의 inductive bias가 확산 모델의 성능에 중요하지 않다는 것을 보임 → 이를 통해 U-Net을 트랜스포머 같은 일반적인 디자인으로 대체하는 것이 가능함
- 이렇게 일반적인 디자인의 모델 구조를 사용하는 경우 아래와 같은 장점들이 있음!
- 다른 분야에서 좋은 성능을 보인 학습 방법을 적용할 수 있음
- 확장성, 강인성, 효율성과 같은 특성을 유지하는 것이 가능
- 이렇게 트랜스포머를 기반으로 하는 새로운 확산 모델을 제안 → Diffusion Transformers (DiT)
- DiT는 ViTs (Vision Transformers)의 장점을 가지고 있음 → 기존 합성곱 네트워크 (e.g. ResNet)에 비해 시각적 인지에 대해 더욱 효율적으로 확장 가능
- 트랜스포머의 확장성
- 네트워크의 복잡도 (Gflops로 측정)와 샘플의 품질 (FID로 측정) 사이에 강한 상관관계가 있음을 의미
- DiT는 VAE의 잠재 공간 내에서 확산 모델을 학습하는 LDM (Latent Diffusion Models) 방식 사용
- 단순히 DiT를 확장하여 높은 용량의 백본 모델 (118.6Gflops)를 통해 LDM을 학습 → 클래스 조건부 (Class Conditional) 256x256 ImageNet 생성 벤치마크에서 SoTA 성능 달성 = 2.27 FID
2. Diffusion Transformers
2.1. Preliminaries
Diffusion Formulation
- 본 논문의 구조를 살펴보기 전에 확산 모델 (DDPMs)를 이해하기 위한 기본 개념들을 알아보자
- 순방향 노이징 과정 (forward noising process)
- 가우시안 확산 모델은 점진적으로 실제 데이터에 노이즈를 추가하는 것을 가정
- $x_0 : q(x_t | x_0) = \mathcal{N} (x_t ; \sqrt{\bar{\alpha}} x_0 , (1-\bar{\alpha}_t)I)$
- $\bar{\alpha}$: 상수 하이퍼파라미터
- Reparameterization 기법을 사용하면 다음과 같이 샘플링 수행 가능
- $x_t = \sqrt{\bar{\alpha}}x_0 + \sqrt{1-\bar{\alpha}} \epsilon_t$ 여기서 $\epsilon_t \sim \mathcal{N} (0, I)$
- 역방향 과정 (reverse process)
- 모델은 $x_0$의 log-likelihood의 variational lower bound를 통해 학습
- $\mathcal{L} (\theta) = -p (x_0|x_1) + \sum_t \mathcal{D}{KL} (q^* (x{t-1}|x_t, x_0)||p_{\theta} (x_{t-1}|x_t))$
- $q^*$와 $p_{\theta}$가 가우시안이므로 $\mathcal{D}_{KL}$은 두 분포의 평균과 공분산으로 표현 가능
- $\mu_{\theta}$를 노이즈 예측 네트워크 $\epsilon_{\theta}$로 reparameterize하여 모델은 단순히 예측된 노이즈 $\epsilon_{\theta}(x_t)$와 실제 샘플링 된 가우시안 노이즈 $\epsilon_t$ 사이의 평균 제곱 오차 (mean-squared error)를 통해 학습될 수 있음 → $\mathcal{L}{simple} (\theta) = || \epsilon{\theta}(x_t) - \epsilon_t ||_2^2$
- 그러나 학습된 역방향 과정의 공분산 $\Sigma_{\theta}$를 통해 확산 모델을 학습하기 위해서는 전체 $D_{KL}$항이 최적화 될 필요가 있음
- 이에 따라 본 논문에서는 Nichol and Dhariwal’s의 접근을 사용 → $\mathcal{L}{simple}$로 $\epsilon\theta$를 학습하고 전체 $\mathcal{L}$로 $\Sigma_{\theta}$를 학습
- $p_{\theta}$가 학습되면 $x_{t_{\max}} \sim \mathcal{N}$으로 부터 새로운 이미지가 샘플링 될 수 있고 reparameterization 기법을 통해 샘플링 $x_{t-1} \sim p_{\theta} (x_{t-1} |x_t)$가 수행될 수 있음
- 모델은 $x_0$의 log-likelihood의 variational lower bound를 통해 학습
Classifier-free Guidance
- 조건부 확산 모델 (Conditional diffusion models)는 추가적인 입력 정보를 가짐 (e.g. 클래스 라벨 $c$)
- 이 경우 역방향 과정은 $p_{\theta} (x_{t-1} | x_t, c)$가 되며 $\epsilon_{\theta}$와 $\Sigma_{\theta}$는 $c$를 조건으로 가짐
- 이런 세팅에서 classifier-free guidance를 활용하여 샘플링 과정에서 $\log p(c|x)$가 높은 $x$를 찾을 수 있도록 함
- 베이즈 규칙 (Bayes rule)에 따라 $\log p(c|x) \propto \log p(x|c) - \log p(x)$가 되며 이로 인해 $\bigtriangledown _x \log p(c|x) \propto \bigtriangledown _x \log p(x|c) - \bigtriangledown_x \log p(x)$가 됨
- 확산 모델의 출력을 score 함수로 간주 → DDPM의 샘플링 과정은 다음 식에 의해 높은 $p(x|c)$를 가지는 $x$를 샘플링하도록 유도될 수 있음 → $\hat{\epsilon}{\theta} (x_t, c) = \epsilon{\theta} (x_t, \emptyset) + s \cdot \triangledown_x \log p (x|c) \propto \epsilon_{\theta} (x_t, \emptyset) + s \cdot (\epsilon_{\theta} (x_t, c) - \epsilon_{\theta} (x_t, \emptyset))$
- 여기서 $s>1$은 가이드의 정도를 나타냄 ($s=1$은 일반적인 샘플링과 동일)
- $c=\emptyset$으로 확산 모델을 평가하는 것은 학습 동안 랜덤하게 $c$를 드랍 아웃하고 이를 학습된 “null” 임베딩 $\emptyset$으로 대체하는 것
- Classifier-free guidance는 월등히 개선된 샘플을 생성하기 위한 기법으로 널리 알려져있음
Latent Diffusion Models
- 확산 모델을 직접적으로 고해상도 픽셀 공간에서 학습하는 것은 계산적으로 매우 비효율적
- LDM (Latent Diffusion Models)은 이런 문제를 2단계의 접근 방법으로 해결
- 오토인코더 (AutoEncoder) 학습 → 학습된 인코더 (Encoder) $E$를 통해 이미지를 더 작은 지역적인 표현 (spatial representations)으로 압축
- 이미지 $x$ $x$$z=E(x)$를 통해 확산 모델을 학습 (이때 $E$는 학습하지 않음) → 학습 이후 확산 모델을 통해 표현 $z$를 샘플링하고 이를 학습된 디코더를 통해 이미지로 디코딩 $x = D(z)$
- 그림 2에서 볼 수 있듯이 LDM은 좋은 성능을 달성하면서도 ADM과 같은 픽셀 공간 확산 모델에 비해 훨씬 적은 Gflops를 사용

- 이에 따라 본 논문에서도 DiT를 잠재 공간에서 적용
2.2. Diffusion Transformer Design Space
- 본 논문에서는 새로운 구조의 확산 모델인 DiT (Diffusion Transformers)를 제안
- DiT는 패치 (patch)의 시퀀스를 기반으로 동작하는 ViT (Vision Transformers) 구조를 기반으로 함
- 전체적인 DiT의 구조는 그림 3 참고

Patchify
- DiT의 입력은 공간적인 표현인 $z$
- 256x256x3 이미지에 대해 $z$는 32x32x4의 차원을 가짐
- DiT의 첫번째 레이어는 “패치화 (patchify)”를 수행
- 공간적인 입력을 $T$개의 토큰의 시퀀스로 변환
- 각각은 $d$ 차원을 가지며 입력의 각 패치를 선형적으로 임베딩
- 모든 입력 토큰들에 대해 일반적인 ViT의 positional embedding을 적용
- 패치와에 의해 생성된 토큰의 수 $T$는 패치 크기에 대한 하이퍼파라미터 $p$에 의해 결정됨
- 그림 4에서 볼 수 있듯이 $p$를 절반으로 하면 $T$는 4배가 됨 → 트랜스포머 연산량도 최소 4배 증가

- 본 논문에서는 DiT 디자인에 $p=2, 4, 8$ 추가
DiT Block Design
- 패치화 이후 입력 토큰들을 연속적인 트랜스포머 블럭들로 연산 수행
- 이때 확산 모델은 노이즈 이미지 입력 뿐 아니라 추가적인 조건 정보들을 사용 → 노이즈 시간 스텝 $t$, 클래스 라벨 $c$, 자연어 등..
- 본 논문에서는 이런 조건 입력들을 처리하기 위한 다른 구조 4가지를 소개
- 이 4가지 구조들도 위의 그림 3에서 살펴볼 수 있음
- In-context conditioning
- $t, c$의 벡터 임베딩을 단순히 입력 시퀀스에 두개의 추가적인 토큰으로 추가
- 이미지 토큰과 차이가 없다고 간주하고 사용 → 일반적인 ViT 블럭을 변경 없이 사용
- 마지막 블럭 이후 시퀀스로부터 조건에 대한 토큰을 제거
- 해당 방식은 무시할만큼 적은 크기의 새로운 Gflops가 연산에 추가
- Cross-atttention block
- 이미지 토큰 시퀀스와 분리한 상태로 $t, c$의 임베딩을 연결 (concatenate)
- 트랜스포머 블럭에는 멀티 헤드 셀프 어텐션 블럭 이후에 추가적인 멀티 헤드 크로스 어텐션 레이어가 추가됨
- 조건을 크로스 어텐션 레이어에 추가하는 방식 사용
- 크로스 어텐션은 모델에 가장 많은 Gflops를 추가 → 거의 15%의 오버헤드 발생
- Adaptive layer norm (adaLN) block
- 트랜스포머의 레이어 정규화 (layer norm) 층들을 adaLN (Adaptive Layer Norm)으로 대체
- 직접 스케일과 이동 (shift) 파라미터 $\gamma, \beta$를 학습하는 대신 $t, c$의 임베딩 벡터들의 합으로 부터 파라미터들을 취득
- adaLN은 가장 적은 Gflops를 추가 → 계산 효율적
- adaLN-Zero block
- ResNet에서 각 residual 블럭을 항등 함수 (identity function)으로 초기화하는 것이 좋다는 것을 보임
- 예를 들어, Goyal 등은 각 블록에서 마지막 배치 정규화(Batch Norm)의 스케일 계수 $\gamma$를 0으로 초기화하면, 대규모 지도 학습 환경에서 학습 속도가 빨라진다는 것을 발견
- 확산 모델의 U-Net은 유사한 초기화 전략을 사용 → 모든 residual 연결 이전의 마지막 각 블럭의 합성곱 레이어를 0으로 초기화
- 본 논문의 adaLN DiT 블럭도 동일한 변경을 수행
- $\gamma, \beta$에 대한 회귀 외에도 추가적으로 DiT 블럭 내에서 residual 연결이 적용되기 전 사용되는 차원 방향 스케일 파라미터 $\alpha$ 또한 함께 회귀
- 모든 $\alpha$에 대해 0 벡터를 출력하도록 MLP를 초기화 → 이는 모든 DiT 블럭을 항등 함수로 초기화하는 것.
- adaLN-Zero는 무시할만한 정도의 Gflops를 모델에 추가
- ResNet에서 각 residual 블럭을 항등 함수 (identity function)으로 초기화하는 것이 좋다는 것을 보임
Model Size
- 본 논문에서는 $N$개의 DiT 블럭의 시퀀스를 적용 → 각각은 은닉 차원 크기 $d$를 사용
- 일반적인 ViT와 동일하게 모델 크기에 따라 $N, d$, 어텐션 헤드를 모두 증가
- 크기에 대한 4개의 설정 사용 → DiT-S, DiT-B, DiT-L, DiT-XL
- 표 1을 통해 각 설정의 구체적인 수치를 살펴볼 수 있음

Transformer Decoder
- 마지막 DiT 블럭 이후 이미지 토큰의 시퀀스를 디코딩하여 아래 2가지를 출력
- 예측된 노이즈
- 예측된 대각 공분산 (diagonal covariance prediction)
- 이 두 출력은 원본 입력과 같은 크기를 가짐
- 일반적인 선형 디코더를 사용 → 마지막 레이어 정규화를 적용 (adaLN의 경우 adaptive 적용)하고 선형적으로 각 토큰을 $p \times p \times 2C$ 텐서로 디코딩 ($C$는 채널의 수)
- 최종적으로 디코딩된 토큰들을 원래의 공간적 배열로 재구성하여 예측된 노이즈와 공분산 취득
3. Experimental Setup
- 모델의 이름은 모델의 설정과 잠재 패치 크기 $p$를 사용하여 결정
- 예시: DiT-XL/2는 XL 사이즈 설정이면서 $p=2$인 경우
Training
- 클래스 조건부 잠재 DiT 모델을 ImageNet 데이터셋의 256x256, 512x512 해상도 이미지들에 대하여 학습 수행
- 마지막 선형 레이어는 0으로 초기화하고 다른 레이어들에 대해서는 ViT의 가중치 초기화 기법 사용
- 모든 모델은 AdamW를 사용하여 학습
- 학습 관련 설정
- 학습률 = $1 \times 10^{-4}$
- 배치 사이즈 = 256
- 학습률에 대한 warm up, 정규화, weight decay 사용하지 않음 → 적용하지 않아도 안정적 학습 수행
- 학습 동안 DiT의 가중치에 대해 EMA (Exponential Moving Average)를 decay 0.999로 수행
Diffusion
- Stable Diffusion에서 사용하는 사전학습 된 VAE (Variational AutoEncoder) 사용
- VAE의 다운샘플링 인수는 8로 설정
- 256x256x3의 RGB 이미지 $x$ → $z=E(x)$는 32x32x4의 크기를 가짐
- 모든 실험에 대해 확산 모델은 $\mathcal{Z}$-공간에서 작동
- 확산 모델로 부터 새로운 잠재 변수를 샘플링 한 이후 VAE 디코더를 사용하여 이를 픽셀로 디코딩 → $x=D(z)$
- ADM의 확산 모델 하이퍼파라미터를 그대로 사용
Evaluation Metrics
- FID (Frechet Inception Distance)를 사용하여 모델 크기 확장에 따른 성능 변화를 측정
- Inception Score, sFID, Precision/Recall은 보조적인 지표로 사용
Compute
- Jax를 사용하여 모든 모델을 구현
- TPU-v3를 통해 학습 수행
- 가장 큰 모델인 DiT-XL/2의 경우 거의 5.7 iter/sec의 속도로학습 했으며 TPU v3-256 pod에서 글로벌 배치 크기 256으로 학습
4. Experiments
DiT Block Design
- 서로 다른 블럭 디자인에 따라 4개의 높은 Gflop을 가지는 DiT-XL/2 모델 학습
- In-context (119.4 Gflops)
- Cross-attention (137.6 Gflops)
- Adaptive Layer Norm - adaLN (118.6 Gflops)
- adaLN-zero (118.6 Gflops)
- 그림 5를 통해 결과를 살펴볼 수 있음

- adaLN-Zero 블럭이 가장 계산 효율적이면서 cross-attention, in-context 보다 낮은 FID를 가짐
- in-context에 비해 절반 정도의 FID 값을 가질만큼 차이를 보임 → 조건을 제공하는 방식이 모델의 품질에 큰 영향을 미치는 것을 알 수 있음
- 초기화 또한 성능에 중요한 영향을 미침
- 각 DiT 블럭을 항등 함수로 초기화한 adaLN-Zero는 일반 adaLN의 성능을 크게 능가 → 이에 따라 앞으로 논문의 모든 모델은 adaLN-Zero DiT 블럭을 사용
Scaling Model Size and Patch Size
- 모든 모델 설정 (S, B, L, XL)과 패치 사이즈 (8, 4, 2)에 대해 12개의 DiT 모델들을 학습
- DiT-L과 DiT-XL은 다른 설정들에 비해 상대적인 Gflops가 훨씬 가까움
- 그림 2의 왼쪽 그래프를 통해 400K 만큼 학습한 각 모델의 Gflops를 살펴볼 수 있음

- 모든 경우에서 모델의 크기를 증가시키고 패치 사이즈는 감소시킬수록 확산 모델의 성능을 크게 향상시킴
- 그리고 그림 6의 위쪽을 보면 모델의 크기를 증가시키고 패치 크기는 고정했을 때 어떻게 FID가 변하는지 확인할 수 있음 → 트랜스포머가 더 깊고 넓어질수록 모든 학습 단계에서 FID가 개선됨

- 또한 그림 6의 아래쪽을 보면 패치 사이즈가 감소하고 모델의 크기가 고정되었을 때 FID를 확인할 수 있음
- 단순히 DiT에 의해 처리되는 토큰의 수만 확장시켜서 학습해서 FID가 크게 개선됨
DiT Gflops are Critical to Improving Performance
- 그림 6의 결과를 통해 파라미터의 수만이 DiT 모델의 품질에 영향을 미치는 것이 아님을 알 수 있음
- 모델의 크기를 고정하고 패치 사이즈를 감소시키는 경우 전체 파라미터 수는 변하지 않고 Gflop만 증가하게 됨
- 이에 따라 모델의 Gflops를 증가시키는 것이 실제로 성능 향상의 열쇠임을 알 수 있음! 🔑
- 이를 더욱 알아보기 위해 400K 학습을 수행했을 때 모델 Gflops의 변화에 따른 FID-50K의 그래프를 보여줌 → 그림 8을 통해 확인 가능

- 전체 Gflops가 유사한 경우 DiT의 설정이 달라도 유사한 FID 값을 가짐 (e.g. DiT-S/2와 DiT-B/4)
- 이를 통해 모델 Gflops와 FID-50K가 강력한 음의 상관관계를 가지는 것을 알 수 있음 → 추가적인 모델 연산이 DiT 모델 개선의 중요한 요소!
- 그림 12를 확인하면 Inception Score 같은 다른 점수들에서도 동일한 경향이 나타나는 것을 알 수 있음

Larger DiT Models are More Compute-Efficient
- 그림 9를 통해 모든 DiT 모델 전체 학습 연산의 함수에 대한 FID 그래프를 확인할 수 있음

- 학습 연산을 “모델 Gflops · 배치 사이즈 · 학습 스텝 · 3”으로 계산
- 작은 DiT 모델의 경우 오래 학습을 해도 더 적게 학습한 큰 DiT 모델에 비해서 상대적으로 계산이 비효율적
- 유사하게 패치 사이즈 외에 다른건 동일한 모델들도 학습 Gflops를 조절하여 다른 성능을 가지게 할 수 있음
- 예를 들어 $10^{10}$ Gflops 이후 XL/4가 XL/2의 성능을 능가하는 것을 알 수 있음
Visualizing Scaling
- 모델 확장에 따른 샘플의 품질을 시각화한 결과를 그림 7을 통해 살펴볼 수 있음

- 12개의 DiT 모델 각각으로부터 샘플링 한 이미지들을 확인할 수 있으며 모두 동일한 노이즈 $x_{t_{\max}}$와 클래스 라벨을 사용
- 어떻게 모델을 확장하는 것이 DiT의 샘플 품질에 영향을 미치는지 시각적으로 확인 가능
- 모델의 크기와 토큰의 수를 늘리는 것 모두 시각적 품질 개선에 큰 영향을 미침
4.1. State-of-the-Art Diffusion Models
256 x 256 ImageNet
- 높은 Gflop 모델로 지속적으로 학습 수행 → DiT-XL/2로 7M 스텝 학습
- 해당 모델을 통해 샘플링 한 결과를 그림 1에서 살펴볼 수 있음
- 표 2를 통해 최신의 클래스 조건부 생성 모델과 성능 비교 결과를 확인할 수 있음

- Classifier-free guidance를 사용한 경우 DiT-XL/2가 FID-50K에서 기존 최고 성능을 보인 LDM (FID-50K=3.6)보다 더 좋은 성능을 보이는 것을 확인 (FID-50K=2.27)
- 본 논문의 기법인 DiT가 기존 최고 성능을 보인 StyleGAN-XL을 포함한 모든 이전의 생성 모델보다 낮은 FID를 가짐
- 또한 DiT-XL/2는 LDM-4나 LDM-8에 비해 높은 recall을 달성했으며 오직 2.35M 스텝 학습만으로 XL/2는 모든 이전의 확산 모델 성능을 뛰어넘음 (FID=2.55)
512 x 512 ImageNet
- ImageNet 512 x 512 해상도에서 새로운 DiT-XL/2 모델을 3M 만큼 학습 → 256 x 256 모델과 동일한 하이퍼파라미터 사용
- 패치 사이즈 2를 사용하여 XL/2 모델은 64 x 64 x 4의 잠재 입력을 패치화하면 1024 토큰이 처리됨 (524.6 Gflops)
- 표 3은 최신 모델들과의 성능 비교 결과를 보임

- XL/2는 또 다시 모든 이전의 확산 모델 성능을 뛰어넘음
- 기존 최고 성능의 모델인 ADM의 FID=3.85보다 더 뛰어난 3.04의 FID 달성
- 토큰의 수가 늘어나도 XL/2는 여전히 계산 효율적 (ADM은 1983 Gflops, ADM-U는 2813 Gflops를 사용하는 반면 XL/2는 524.6 Gflops 사용)
- 그림 1의 고해상도 이미지가 해당 XL/2 모델로 샘플링 한 결과들
4.2. Scaling Model vs. Sampling Compute
- 더 작은 모델의 DiT가 더 많은 샘플링 연산을 수행하면 큰 모델의 성능을 뛰어넘을 수 있을까?
- 400K 학습 스텝 이후 모든 12개의 DiT 모델에 대해 FID를 계산 → 이미지 당 [16, 32, 64, 128, 256, 1000] 샘플링 스텝 사용
- 그림 10에서 결과를 살펴볼 수 있음

- DiT-L/2가 1000 샘플링 스텝을 사용한 결과 vs. DiT-XL/2가 128 스텝을 사용한 결과 비교
- L/2가 각 이미지 생성에 80.7 Tflops를 사용한 반면 XL/2는 5x 적은 15.2 Tflops 사용
- 그럼에도 불구하고 XL/2가 더 나은 FID-10K (23.7 vs. 25.9)를 달성
- 이를 통해 샘플링 계산을 증가시키는 것은 모델의 부족한 연산을 보상하지 못한다는 것을 알 수 있음
5. Conclusion
- 본 논문에서는 새로운 DiTs (Diffusion Transformers)를 소개
- 확산 모델에 단순한 트랜스포머 기반 백본을 사용한 모델
- 기존 U-Net 기반 모델의 성능을 뛰어넘었으며 트랜스포머 모델이 가지는 확장성 특성을 그대로 가짐
- Future work - 더 큰 모델과 토큰 수를 사용하여 DiT를 확장하는 것
반응형