Analysis of the Effect of Class Imbalance on Transfer Learning

    Abstract: Class imbalance in datasets is a common problem in machine learning, and transfer learning is no exception. However, little research has examined how class imbalance affects transfer learning. To address this gap, this paper analyzes how several imbalance-handling methods affect transfer learning: oversampling, undersampling, weighted random sampling, weighted cross-entropy loss, Focal Loss, and L2RW (Learning to Reweight), an algorithm based on meta-learning. The first three methods remove the imbalance from the dataset by random sampling. Weighted cross-entropy loss and Focal Loss leave the dataset unchanged and instead modify the loss function of standard classification algorithms so that they can be trained on imbalanced data. L2RW uses a meta-learning mechanism to adjust the weights of training samples dynamically, aiming at better generalization. Extensive experiments show that, among these methods, oversampling and weighted random sampling are better suited to transfer learning.
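The three sampling-based methods can be sketched at the level of index selection. The following is a minimal NumPy illustration, not the authors' code. The function names are hypothetical, and the sketch assumes integer class labels. It also follows one common convention: oversampling grows every class to the size of the largest class, and undersampling shrinks every class to the size of the smallest. For weighted random sampling it assumes each sample's weight is the inverse of its class count. In PyTorch the analogous built-in is `torch.utils.data.WeightedRandomSampler`, although the source does not say which framework was used.

```python
import numpy as np

# Hypothetical helpers: each returns row indices into the training set, so the same
# indices can be applied to features and labels alike.


def random_oversample(y, rng):
    """Duplicate randomly chosen samples of each smaller class (with replacement)
    until every class has as many samples as the largest class."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    parts = []
    for c, n in zip(classes, counts):
        members = np.flatnonzero(y == c)
        extra = rng.choice(members, size=target - n, replace=True)
        parts.append(np.concatenate([members, extra]))
    return rng.permutation(np.concatenate(parts))


def random_undersample(y, rng):
    """Keep a random subset (without replacement) of each class, sized to the smallest class."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    parts = [rng.choice(np.flatnonzero(y == c), size=target, replace=False) for c in classes]
    return rng.permutation(np.concatenate(parts))


def weighted_random_sample(y, num_samples, rng):
    """Draw num_samples indices with replacement, weighting each sample by 1 / (its class
    count) so that every class is drawn with equal probability in expectation."""
    _, inverse, counts = np.unique(y, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return rng.choice(len(y), size=num_samples, replace=True, p=weights / weights.sum())


rng = np.random.default_rng(0)
y = np.array([0] * 90 + [1] * 10)  # 9:1 imbalanced toy labels
print(np.bincount(y[random_oversample(y, rng)]))   # [90 90]
print(np.bincount(y[random_undersample(y, rng)]))  # [10 10]
print(np.bincount(y[weighted_random_sample(y, 100, rng)]))  # roughly equal; varies with the seed
```

Oversampling and undersampling balance the class counts exactly. Weighted random sampling balances them only in expectation, but it does so without building a fixed resampled copy of the dataset.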
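The two loss-based methods keep the data unchanged and reweight each sample's loss instead. Below is a minimal NumPy sketch of both, assuming softmax outputs over integer labels. Weighted cross-entropy scales each sample's loss by the weight of its class and then normalizes by the sum of those weights. This is one common convention, and it is the one PyTorch's `CrossEntropyLoss(weight=...)` uses with `reduction='mean'`. Focal Loss multiplies the cross-entropy term by (1 - p_t)^gamma, so samples that are already classified confidently contribute less. The inverse-frequency class weights and the default `gamma=2.0` are illustrative assumptions, not values taken from the paper.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def inverse_frequency_weights(y, num_classes):
    """Illustrative class weights n / (K * n_c); assumes every class occurs in y."""
    counts = np.bincount(y, minlength=num_classes)
    return len(y) / (num_classes * counts)


def weighted_cross_entropy(logits, y, class_weights):
    """Sum of -w[y_i] * log p_i[y_i], divided by the sum of the weights w[y_i]."""
    p_true = softmax(logits)[np.arange(len(y)), y]
    w = class_weights[y]
    return np.sum(-w * np.log(p_true)) / np.sum(w)


def focal_loss(logits, y, gamma=2.0, alpha=None):
    """Mean of -alpha_t * (1 - p_t)**gamma * log(p_t); alpha is an optional per-class array."""
    p_t = softmax(logits)[np.arange(len(y)), y]
    alpha_t = 1.0 if alpha is None else alpha[y]
    return np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t))


rng = np.random.default_rng(0)
y = np.array([0] * 8 + [1] * 2)         # toy 4:1 imbalanced labels
logits = rng.normal(size=(10, 2))       # stand-in for model outputs
w = inverse_frequency_weights(y, 2)     # w == [0.625, 2.5]
print(weighted_cross_entropy(logits, y, w), focal_loss(logits, y))
```

A convenient sanity check: with `gamma=0` and no `alpha`, `focal_loss` reduces to the plain mean cross-entropy, which equals `weighted_cross_entropy` with all class weights set to 1.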
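L2RW's reweighting step can be illustrated on a model small enough to write the meta-gradient by hand. The sketch below applies the one-step rule of Ren et al. (2018) to binary logistic regression. It assumes a small, clean, class-balanced validation set, which L2RW relies on. All function names, the learning rate, the batch size, and the toy data are hypothetical. Consider a virtual SGD step theta - lr * sum_i eps_i * g_i, where g_i is the loss gradient of training example i and g_val is the mean validation-loss gradient. At eps = 0, the derivative of the validation loss after this step with respect to eps_i is -lr * (g_val . g_i). Each weight is therefore taken proportional to max(g_val . g_i, 0) and then normalized. A deep network, as used in transfer learning, would compute the same quantity with automatic differentiation instead of a closed form.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def per_example_grads(theta, X, y):
    """Gradient of each example's binary cross-entropy w.r.t. theta; shape (n, d)."""
    return (sigmoid(X @ theta) - y)[:, None] * X


def l2rw_weights(theta, X_batch, y_batch, X_val, y_val):
    """One-step L2RW weights for a training mini-batch: max(g_val . g_i, 0), normalized.
    The positive factor lr in the meta-gradient cancels during normalization."""
    g_batch = per_example_grads(theta, X_batch, y_batch)
    g_val = per_example_grads(theta, X_val, y_val).mean(axis=0)
    w = np.maximum(g_batch @ g_val, 0.0)
    total = w.sum()
    return w / total if total > 0 else w  # all-zero weights: this batch is skipped


def train_l2rw(X_train, y_train, X_val, y_val, lr=0.5, steps=200, batch_size=32, seed=0):
    """Hypothetical training loop: reweight each mini-batch, then take a weighted SGD step."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(X_train.shape[1])
    for _ in range(steps):
        idx = rng.choice(len(y_train), size=batch_size, replace=False)
        Xb, yb = X_train[idx], y_train[idx]
        w = l2rw_weights(theta, Xb, yb, X_val, y_val)
        # The gradient of the weighted loss sum_i w_i * l_i is sum_i w_i * g_i.
        theta -= lr * (w @ per_example_grads(theta, Xb, yb))
    return theta


# Toy usage: a rare positive class and a small balanced validation set (both hypothetical).
rng = np.random.default_rng(0)
X = np.c_[rng.normal(size=(500, 2)), np.ones(500)]  # last column acts as a bias term
y = (X[:, 0] + 0.5 * X[:, 1] > 1.5).astype(float)
val = np.concatenate([np.flatnonzero(y == 1)[:10], np.flatnonzero(y == 0)[:10]])
train = np.setdiff1d(np.arange(len(y)), val)
print(train_l2rw(X[train], y[train], X[val], y[val]))
```

Unlike the fixed class weights of a reweighted loss, these weights are recomputed at every step from the current parameters. A training example is up-weighted only when its gradient points in a direction that also reduces the validation loss.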

     
