Skip to content

损失函数

NextRec支持10余种不同任务类型的损失函数。包括pointwise,pairwise,listwise。具体支持的损失函数如下:

Pointwise 损失函数

损失函数名称别名描述适用任务
bcebinary_crossentropy二元交叉熵损失,用于二分类任务二分类
weighted_bce-加权二元交叉熵损失,支持样本权重二分类
focalfocal_loss焦点损失,用于处理类别不平衡问题二分类
cb_focalclass_balanced_focal类别平衡焦点损失,需要 class_counts 参数多分类/二分类
crossentropyce交叉熵损失,用于多分类任务多分类
mse-均方误差损失,用于回归任务回归
mae-平均绝对误差损失,用于回归任务回归

Pairwise 损失函数

损失函数名称描述适用任务
bpr贝叶斯个性化排序 (Bayesian Personalized Ranking) 损失排序
hingeHinge 损失 (SVM 风格)排序
triplet三元组损失,用于学习 item 嵌入表示学习

Listwise 损失函数

损失函数名称描述适用任务
sampled_softmax / softmax采样 Softmax 损失,用于大规模排序排序
infonceInfoNCE 损失,对比学习常用表示学习
listnetListNet 损失,基于列表的排序排序
listmleListMLE 损失,最大似然估计方法排序
approx_ndcg近似 NDCG 损失,直接优化 NDCG 指标排序

使用示例

在基类模型BaseModelcompile参数,通过修改该方法中的loss以及loss_params参数来配置损失函数。示例代码:

python
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    loss="binary_crossentropy", # 设置损失函数
)

NextRec CLI 集成

在命令行工具NextRec CLI中,通过修改训练配置文件中的loss参数来进行调整。当多任务时,需要依次设置不同任务的损失函数。

yaml
train:
  loss:
    - 'bce'
    - 'focal'
  loss_params:
    - {} 
    - alpha: 0.8
      gamma: 2.0

下一步

基于 MIT 许可证开源