
NextRec: A Modern Recommender System Framework

A unified, efficient, and extensible recommender system framework built on PyTorch.


Installation

```bash
pip install nextrec
```

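For WandB or SwanLab experiment tracking, also install pip install "nextrec[tracking]". The default installation leaves out these two optional dependencies because, on some Linux environments, installing wandb fails when no Go build toolchain is available.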

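For ONNX export or ONNX Runtime inference, also install pip install "nextrec[onnx]". The default installation leaves out the ONNX-related dependencies because installation fails on some Linux environments where onnxruntime is not available.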
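Because both extras are opt-in, it can help to check which optional modules are importable before enabling tracking or ONNX export. The following is a minimal sketch using only the standard library. It assumes that the [tracking] extra provides the wandb and swanlab modules and that the [onnx] extra provides onnx and onnxruntime; the mapping in OPTIONAL_EXTRAS and the helper missing_modules are hypothetical, not part of NextRec's API.

```python
import importlib.util

# Hypothetical mapping (an assumption, not read from NextRec's packaging metadata):
# extra name -> top-level modules that extra is expected to install.
OPTIONAL_EXTRAS = {
    "tracking": ["wandb", "swanlab"],
    "onnx": ["onnx", "onnxruntime"],
}


def missing_modules(extra: str) -> list[str]:
    """Return the modules listed for `extra` that cannot be found in this environment."""
    # find_spec returns None for a top-level module that is not installed.
    return [m for m in OPTIONAL_EXTRAS[extra] if importlib.util.find_spec(m) is None]


if __name__ == "__main__":
    for extra in OPTIONAL_EXTRAS:
        missing = missing_modules(extra)
        if missing:
            print(f"nextrec[{extra}] looks incomplete; missing: {', '.join(missing)}")
            print(f'  try: pip install "nextrec[{extra}]"')
        else:
            print(f"nextrec[{extra}] dependencies found")
```

The check only looks for installed modules; it does not import them, so it stays fast and will not trigger wandb or onnxruntime initialization.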

Quick Start

```python
import pandas as pd
from sklearn.model_selection import train_test_split

from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM

df = pd.read_csv("https://raw.githubusercontent.com/zerolovesea/NextRec/main/dataset/movielens_100k.csv")

dense_features = [DenseFeature("age")]
sparse_features = [
    SparseFeature("user_id", vocab_size=df["user_id"].max() + 1, embedding_dim=16),
    SparseFeature("item_id", vocab_size=df["item_id"].max() + 1, embedding_dim=16),
    SparseFeature("gender", vocab_size=df["gender"].max() + 1, embedding_dim=16),
    SparseFeature("occupation", vocab_size=df["occupation"].max() + 1, embedding_dim=16),
]

train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)

model = DeepFM(
    dense_features=dense_features,
    sparse_features=sparse_features,
    mlp_params={"hidden_dims": [256, 128], "activation": "relu", "dropout": 0.2},
    target="label",
    device="cpu",
    session_id="movielens_deepfm",   # manages experiment logs and checkpoints
)

model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    loss="binary_crossentropy",
)

model.fit(
    train_data=train_df,
    valid_data=valid_df,
    metrics=["auc", "recall", "precision"],
    epochs=2,
    batch_size=512,
    shuffle=True,
)
```
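The quick start sets vocab_size=df[col].max() + 1, which is only valid when a sparse column already holds integer ids starting at 0. If your own data stores raw categories (for example gender as "M"/"F"), a minimal sketch for encoding them first, using pandas.factorize, might look like this. The helper encode_sparse_columns is hypothetical and not part of NextRec.

```python
import pandas as pd


def encode_sparse_columns(
    df: pd.DataFrame, columns: list[str]
) -> tuple[pd.DataFrame, dict[str, int]]:
    """Map each column's values to contiguous integer ids 0..n-1.

    Returns an encoded copy plus a {column: vocab_size} dict, where
    vocab_size == encoded[column].max() + 1 (the quick-start convention).
    Note: pandas.factorize assigns -1 to missing values, so fill NaNs first.
    """
    out = df.copy()  # leave the caller's DataFrame untouched
    vocab_sizes: dict[str, int] = {}
    for col in columns:
        # sort=True gives a deterministic id order (sorted unique values).
        codes, uniques = pd.factorize(out[col], sort=True)
        out[col] = codes
        vocab_sizes[col] = len(uniques)
    return out, vocab_sizes


if __name__ == "__main__":
    toy = pd.DataFrame({"gender": ["M", "F", "M"], "occupation": ["writer", "engineer", "writer"]})
    encoded, sizes = encode_sparse_columns(toy, ["gender", "occupation"])
    print(encoded)
    print(sizes)
```

Encode before calling train_test_split so the training and validation sets share one id mapping; each returned size can then be passed as SparseFeature(col, vocab_size=sizes[col], embedding_dim=16).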


Released under the MIT License.