Paper Review/Knowledge Distillation

[Paper Review] Logit Standardization in Knowledge Distillation

hakk35 2024. 12. 1. 19:56

This is a Korean review of

"Logit Standardization in Knowledge Distillation"
presented at CVPR 2024.

TL;DR

  • KD에서 teacher와 student의 soft label (i.e., prediction)을 얻을 때 사용하는 shared temperature은 teacher와 student logits의 range와 variance의 mandatory exact match를 전제로 함. (in fact, relation is important.)
  • 기존 방법의 한계를 극복하기 위해, adaptive temperature로 weighted logit standard deviation을 사용함.
  • 이를 활용해, softmax를 적용하기 전, Z-score pre-process를 수행함. 이를 통해 student가 magnitude match가 아닌 핵심적인 logit relation에 집중하도록 함.

Expression Transformation

  Student prediction
Convential Form
(including temperature T)
q(zn)(k)=exp(zn(k)/T)m=1Kexp(zn(m)/T)
Constrained Entropy-Maximization Form
(with Lagrangian multipliers α,β)
q(zn)(k)=exp(βnzn(k))m=1Kexp(βnzn(m))
General Form
(with hyper-parameters aS,bS)
q(zn;aS,bS)(k)=exp[(zn(k)aS)/bS]m=1Kexp[(zn(m)aS)/bS]
Logit Standardization
(with mean zn, weighted standard deviation σ(zn))
q(zn;zn,σ(zn))(k)=exp(Z(zn;τ)(k))m=1Kexp(Z(zn;τ)(m))

 

 

Introduction

  • 본 논문은 classification과 KD에서 사용되는 softmax가 정보이론의 entropy maximization 원리에서 유도됨을 보이는데, 이 과정에서 temperature가 Lagrangian multiplier로부터 얻을 수 있음을 보임.
  • 이를 바탕으로, teacher와 student의 temperature 간의 무관성(irrelevance)뿐만 아니라, 서로 다른 sample의 temperature 간의 무관성을 규명하여, teacher와 student 간, 서로 다른 sample 간에 반드시 같은 temperature를 적용해야 할 필요가 없음을 보여줌.
  • Teacher와의 capacity gap으로 인해 teacher와 유사한 range와 variance를 가지는 logit을 예측하는 것이 어려운데, 이를 극복하기 위해 adaptive temperature로서 weighted logit standard deviation을 사용하며, Z-score logit standardization을 softmax 적용 전 pre-processing 단계로 제안함.
  • Z pre-processing를 통해 logit의 arbitrary range를 bounded range로 mapping 하여 student logit이 teacher logit의 innate relationship을 보존하고 학습하도록 함.

 

 

Related Work

  • 예측된 확률분포를 smooth 하게 만들기 위해 적용되는 temperature T은 hyper-parameter로서 사전에 지정되어야 하며, 학습되는 동안 고정된 값을 가짐.
  • CTKD는 adversarial learning을 적용해 sample 난이도에 따른 sample별 temperature를 활용했지만, teacher와 student가 동일한 temperature를 공유해야 한다고 가정함.
  • ATKD [paper to read]가 sharpness metric을 제안하고, adaptive temperature를 적용했지만, zero logit mean이라는 ATKD의 가정은 numerial apporximiation에 의존함.
  • 이전 연구를 통해, student와 teacher 간의 exact logits matcing 대신, prediction의 inter-class relation만으로도 충분하지만, 기존의 sharing temperature 적용은 여전히 implicit 하게 exact mathcing 하도록 만듦.

 

 

Background and Notation

  • Student's logit은 zn=fs(xn), teacher's logit은 vn=ft(xn).
  • Temperature T가 포함된 일반적인 softmax function은 아래의 식으로 표현됨.

q(zn)(k)=exp(zn(k)/T)m=1Kexp(zn(m)/T),q(vn)(k)=exp(vn(k)/T)m=1Kexp(vn(m)/T).

  • Knowledge distillation은 아래의 KL divergence를 최소화하여 q(zn)(k)q(vn)(k)를 모방하도록 함.

LKL(q(vn)q(zn))=k=1Kq(vn)(k)log(q(vn)(k)q(zn)(k))

  • 이론적으로 z에 대해서만 optimization 할 때, cross-entropy loss와 동일함.
    → teacher prediction는 학습 동안에 고정된 값이고, student prediction을 최적화하는 과정이기 때문에, v만을 포함하는 항을 상수로 간주하여 cross-entropy loss와 이론적으로 같다고 할 수 있음.

LCE(q(vn)q(zn))=k=1Kq(vn)(k)logq(zn)(k)

  • q(vn)의 negative entropy term로 인해 gradient가 diverge 하기 때문에 경험적으로는 같지 않음.
    v만을 포함하는 항이 상수지만, 해당 값으로 인해 loss의 값이 과도하게 커질 수 있기 때문에, 이로 인해 발생하는 하는 발산을 방지하기 위해 다른 방향으로 최적화과정이 진행돼서 cross-entropy loss와 다른 결과를 나타낸다는 뜻?

 

 

Method

  • 1)Teacher와 student 간의, sample 간의 temperature irrelevance2)기존 shared temperature의 두 가지 단점을 언급하고, 3)temperature안의 factor로서 logit standard deviation를 적용한 logit standardization의 pre-process를 제안함.

1) Irrelevance between Temperatures

  • KD와 classification에서의 temperature-involved softmax를 entropy-maximization principle를 통해 유도함. 이는 student와 teacher 간의, sample 간의 서로 다른 temperature 적용 가능성을 시사함.

Derivation of softmax in Classification

maxqL1=n=1Nk=1Kq(vn)(k)logq(vn)(k)

  • 첫 번째 제한조건은 discrete probability density 조건으로, 확률 분포 정의에 따라 합이 1이 되어야 함.

k=1Kq(vn)(k)=1,n

  • 두 번째 제한조건은 기댓값이 목표 클래스 yn의 로짓 값과 일치하도록 하여, 모델이 정확하게 target class를 예측하도록 함.

Eq[vn]=k=1Kvn(k)q(vn)(k)=vn(yn),n.

  • Lagrangian multipliers α1,n, α2,n를 적용하면 다음과 같이 식을 변형할 수 있음.

LT=L1+n=1Nα1,n(k=1Kq(vn)(k)1)+n=1Nα2,n(k=1Kvn(k)q(vn)(k)vn(yn))

  • α1,nα2,n에 대해서 부분 미분하면 constraints로 돌아가고, q(vn)(k)에 대해 미분을 하면, 아래의 식으로 정리됨.

LTq(vn)(k)=1logq(vn)(k)+α1,n+α2,nvn(k)

  • 미분값에 0을 취하면 solution을 얻을 수 있음.

q(vn)(k)=exp(α2,nvn(k))/ZT where ZT=exp(1α1,n)=m=1Kexp(α2,nvn(m))

Derviation of softmax in KD

  • Constrained entropy-maximization optimization는 아래와 같음.

maxqL2=n=1Nk=1Kq(zn)(k)logq(zn)(k)

  • 첫 번째와 두 번째 제한조건은 classification과 동일함.

k=1Kq(zn)(k)=1,n

k=1Kzn(k)q(zn)(k)=zn(yn),n

  • 세 번째 제한조건은, KD에 의해 student가 완전히 학습되었다고 가정하면 teacher logit과 student logit이 동일해야 하기 때문에 추가됨.

k=1Kzn(k)q(zn)(k)=k=1Kzn(k)q(vn)(k),n.

  • Lagrangian multipliers β1,n, β2,n,β3,n를 적용하면 다음과 같이 식을 변형할 수 있음.

LT=L2+n=1Nβ1,n(k=1Kq(zn)(k)1)+n=1Nβ2,n(k=1Kzn(k)q(zn)(k)zn(yn))+n=1Nβ3,nk=1Kzn(k)(q(zn)(k)q(vn)(k))

  • q(zn)(k)에 대해 미분하고, βn=β2,n+β3,n로 정의하면 아래의 solution을 얻을 수 있음.

LSq(zn)(k)=1logq(zn)(k)+β1,n+β2,nzn(k)+β3,nzn(k)

q(zn)(k)=exp(βnzn(k))/ZS where ZS=exp(1β1,n)=m=1Kexp(βnzn(m))

∴ Distinct Temperature

  • 각 constraints는 α 또는 β와 관련이 없음. 따라서, α2,nβn에 대한 explicit expression이 없으므로, manually 정의할 수 있음.
  • βn=α2,n=1/T로 정의하면, shared temperature를 적용하는 KD에서의 prediction으로 표현됨.
  • βn=α2,n=1로 정의하면, 식은 classification에서 흔히 사용되는 전통적인 softmax function이 됨.
  • 따라서, βnα2,n을 선택하면, student와 teacher에 서로 다른 온도를 적용할 수 있음.

Sample-wisely different Temperature

  • 일반적으로 모든 샘플에 대해서 global temperature를 정의하지만 (i.e., any n에 대해서 α2,n,βn은 고정값으로 정의), 이에 대한 제한조건이 없기 때문에 샘플에 따라 서로 다른 온도를 사용하는 것이 가능함.

 

2) Drawbacks of Shared Temperature

  • Entropy-maximization으로부터 유도한 식을, hyper-parameters aS,bS를 추가해, general form으로 만들 수 있음. (cf. aS=0,bS=1/βn을 적용하면 원래대로 돌아감.)

q(zn)(k)=exp(βnzn(k))m=1Kexp(βnzn(m))q(zn;aS,bS)(k)=exp[(zn(k)aS)/bS]m=1Kexp[(zn(m)aS)/bS]

q(vn)(k)=exp(α2,nvn(k))m=1Kexp(α2,nvn(m))q(vn;aT,bT)(k)=exp[(vn(k)aT)/bT]m=1Kexp[(vn(m)aT)/bT]

  • 이상적으로 학생이 완전한 정보를 전달받았다고 하면, KL divergence loss는 minimum에 도달하고, student의 확률분포는 teacher와 일치하게 됨. 즉, k[1,K],q(zn;aS,bS)(k)=q(vn;aT,bT)(k) 따라서, arbitrary pair i,j[1,K] 대해서 아래와 같이 식을 쓸 수 있음.

exp[(zn(i)aS)/bS]exp[(zn(j)aS)/bS]=exp[(vn(i)aT)/bT]exp[(vn(j)aT)/bT](zn(i)zn(j))/bS=(vn(i)vn(j))/bT

  • j에 대해서 average 하면, 즉, zn=1/Km=1Kzn(m),vn=1/Km=1Kvn(m)를 적용하면 아래와 같이 정리할 수 있음.

(zn(i)zn)/bS=(vn(i)vn)/bT

  • 위의 식을 제곱하여 i에 대해서 average 하면, input logit vector에 대한 standard deviation σ로 표현되는 아래의 식으로 정리됨.

σ(zn)2σ(vn)2=1Ki=1K(zn(i)zn)21Ki=1K(vn(i)vn)2=bS2bT2

  • 위의 식을 통해서, well-distilled student의 특성① logit shift② variance matching으로 나타낼 수 있음.

① Logit shift

(zn(i)zn)/bS=(vn(i)vn)/bT

zn(i)=vn(i)+Δn, where Δn=znvn

  • 기존의 shared temperature를 적용 (bS=bT)하면 student와 teacher의 logit사이의 constant shift Δn가 존재함. 즉, traditional KD는 student가 teacher의 shifted logit를 모방하도록 함. 
  • 하지만, 두 모델의 capacity 차이를 생각할 때, student는 teacher처럼 넓은 logit range를 얻을 수 없음. (capacity가 logit range에 영향을 미침? → "Improving distillation for large teacher" 참고)
  • 정확한 logit matching보다 logit rank만을 유지하면 되기 때문에, 기존 KD방법의 logit shift는 student에게 불필요한 어려움을 제공함.

② Variance match

σ(zn)σ(vn)=bSbT

  • 위 식은 temperature ratio와 standard deviation ratio가 동일함을 의미함. 기존 shared temperature bS=bTσ(zn)=σ(vn)가 되도록 강제하기 때문에, student logit의 standard deviation을 제한함.

 

3) Logit Standardization

  • 기존 shared temperature가 가지는 logit shift, variance match의 두 가지 단점을 극복하기 위해, aS,bS,aT,bT를 각 logit의 mean zn과 weighted standard deviation σ(zn)으로 대체하여 Algo. 1과 같이 weighted Z-score function을 구할 수 있음. 

q(zn;zn,σ(zn))(k)=exp(Z(zn;τ)(k))m=1Kexp(Z(zn;τ)(m))

q(vn;vn,σ(vn))(k)=exp(Z(vn;τ)(k))m=1Kexp(Z(vn;τ)(m))

  • Z-score standardization를 사용하면, ① zero mean, ② finite standard deviation, ③ monotonicity, ④ boundedness의 장점이 있음.

① Zero mean

  • Z-score function를 적용하면 standardized vector의 평균이 0이 됨.

② Finite standard deviation

  • Weighted Z-score output의 standard deviation은 1/τ과 동일함.
  • Standardized student와 teacher logit을 zero mean과 definite standard deviation을 가지는 Gaussian-like 분포로 표현가능함.
  • Standardization의 mapping은 many-to-one이기 때문에 그 반대는 정의되지 않음. 즉, 기존 student logit vector의 variacne와 value range는 제한 없음.

③ Monotonicity

  • Z-score는 linear transformation function이기 때문에 monotonic function임. 즉, standardized student logit은 기존의 student logit과 같은 rank를 가짐.
  • Teacher logit 내의 필수적인 고유 관계가 보존됨.

④ Boundedness

  • Standardized logit은 [K1/τ,K1/τ]으로 bounded 되며 이를 통해, 과도하게 큰 값을 피할 수 있음. logit range를 조절하기 위해 base temperature를 정의함.

Toy Case

  • S1는 teacher prediction을 magnitude 측면에서 더욱 유사하게 예측했고, S2는 teacher의 rank를 그대로 유지했음. S1의 경우 S2보다 더 작은 KL divergence loss를 얻었지만, S1는 잘못된 예측을 했고, S2는 올바른 예측을 함. 이를 통해 loss 비교의 모순을 알 수 있음.
  • Z-score를 적용하면, 모든 logit이 re-scaled 되고 magnitude보다 relation이 강조됨. 정규화된 후, loss는 S2 S1보다 낮아지는 것을 확인할 수 있음.

 

 

Experiments

Main Results

CIFAR-100

ImageNet

Ablation Study

  • KD loss의 weight가 증가할수록, vanilla KD의 성능은 급격히 하락하는 것에 반해, Z-score pre-process는 향상된 성능을 얻을 수 있음.

 

Extensions

Logit range

  • 기존 KD를 적용하면, target index에 대해서 student가 teacher만큼의 large logit을 가질 수 없는 반면, 본 논문에서의 방법을 적용하면, 적절한 range의 logit을 만들어 teacher을 잘 모사하도록 함.

Logit variance

  • 기존 KD는 student logit의 variance가 teacher의 variance로 접근하도록 하지만, 본 논문 방식은 student logit이 flexible logit variance를 가지도록 함. standardized logit은 teacher와 동일한 variance를 가짐.

Feature visualizations

Improving distillation for large teacher

  • 큰 teacher가 언제나 좋은 teacher를 의미하는 것이 아니며, 이는 teacher와 student 간의 capacity gap으로부터 기인한다고 설명되어 옴.
  • 본 논문에서는 이를 student가 teacher와 동일한 logit range와 variance를 모사하기 어렵기 때문이라고 해석하고 이를 시각적으로 확인하고자 그림 5를 얻음.
  • 그림 5를 보면, 큰 model(e.g., ResNet50, VGG13) 일수록 logit이 zero mean에 가깝고 작은 standard deviation를 가짐. 반대로, 작은 model의 경우, zero mean에서 많이 떨어져 있고, 큰 variacne를 가짐.  (resnet56와 resnet110의 경우에는 반대 결과처럼 보이는 데, 왜 그러지?)
  • 따라서, 작은 model인 student가 큰 model인 teacher 만큼 compact logit을 얻는 것은 어려움.
  • 그림 5 b를 통해서, student의 모방 능력을 비교할 수 있음. logit mean과 standard deviation에 대해서 vanilla KD는 teacher와 상당 부분 떨어진 logit을 만들어내는 반면, standardized logit mean과 standard deviation에 대해, 제안 방법은 완전한 일치를 보여줌.

 

 

Conclusions

  • Conventional KD에서 global 하고 shared temperature를 사용하는 이론적 근거가 없었기 때문에, entropy maximization을 사용하여, temperature가 Lagrangian multiplier으로부터 유도됨을 보였고, 이를 통해 constant temperature대신 flexible value를 할당할 수 있음을 증명함. 이를 기반으로 Z-score standardization을 pre-process로 제안하여, teacher가 가지는 본질적인 relation을 집중적으로 학습하게 만들었음.