Deep learning

Class imbalance (class weight, sample weight)

비비이잉 2021. 8. 18. 13:54
반응형

Class Weight

Class weight는 전체 학습 데이터에 대해서 클래스별 가중치를 계산하는 방법으로 같은 클래스 내의 데이터 샘플은 같은 weight를 갖는다. 클래스 A에  해당하는 class weights는 sklearn이 제공하는 compute_class_weight 로 계산할 수 있다.

간단히 이야기하면 클래스별 샘플의 역수가 크 클래스의 weight가 된다. 

 

 

 

이 방법 이외에 

Sample Weight

개념적으로는 class weight와 동일하지만 전체 배치가 아니라 미니 배치 상에서 sample의 수를 고려해서 loss를 계산해주는 방법이다. .

아래의 블로그에서 소개한 방법은 크게 3가지이다

 

1. INS. (Inverse of Number of Samples)

배치 내에서의 sample의 빈도

배치 내의 n번째 샘플이 class weight 

 

2. ISNS (Inverse of Square Root of Nubmer of Samples) 

배치 내에서 sample의 빈도의 루트값 

3. ENS(Effective Number of Samples)

sample의 수 뿐 아니라 data sample의 분포를 함께 고려한 유효 샘플수 (effective number of samples)

어떠한 class weight들을 적용시키지 않고 학습한 모델과 비교하여서 INS 방법을 적용하였을때에 모델의 f0.5 score가 2.5퍼센트 개선된 것을 알 수 있는 결과다. 이 결과가 개선된 것이 majority class에서 일어났는지, minority class에서 일어난 결과인지 비교하기 위해서 INS방법을 적용 시킨 후 percent change를 각 각 클래스 별로 나타낸 결과이다. 

사진에 보면 제일 처음에 있는 category_6부터 오른쪽으로 갈 수록 원래의 클래스 샘플수가 작아지는 방향이다. 

minority class의 변화가 컸었던 것인지, majority class의 변화가 컸었던 것인지 확인하기 위한 그래프이다. 

그래프를 보면 알 수 있듯이 x축에서 오른쪽에 있는 그래프의 recall, precision, f0,5 score값의 change의 변화가 큼을 알 수 있다. 

 

 

 

 

<class-Balanced Loss Based on Effective Number of Samples 논문> 

: class-balanced focal loss

: class-balanced sigmoid cross-entropy loss

: class-balanced softmax cross-entropy loss

 

https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf

https://medium.com/gumgum-tech/handling-class-imbalance-by-introducing-sample-weighting-in-the-loss-function-3bdebd8203b4

 

Handling Class Imbalance by Introducing Sample Weighting in the Loss Function

“Nobody is Perfect” This quote not just applies to us humans but also the data that surrounds us. Any data science practitioner needs to…

medium.com

 

 

반응형

'Deep learning' 카테고리의 다른 글

Keras Tuner 케라스 튜너 설치  (0) 2021.08.23
Optimizer  (0) 2021.08.18
SMOTE for imbalanced image dataset  (0) 2021.08.13
GAN(Generative Adversarial Network)  (0) 2021.08.06
Model Performance Measure  (0) 2021.07.23