📊
统一特征抽象
支持 Dense、Sparse、Sequence 三类特征,统一编码、转换与输入接口
pip install nextrec

如果需要 WandB 或 SwanLab 实验跟踪,请额外执行 pip install "nextrec[tracking]"。默认安装不包含这两个可选依赖,以避免在部分 Linux 环境中安装 wandb 时因缺少 Go 编译环境而失败。
如果需要 ONNX 导出或 ONNX Runtime 推理,请额外执行 pip install "nextrec[onnx]"。默认安装不包含 ONNX 相关依赖,以避免在部分 Linux 环境中因 onnxruntime 不可用而安装失败。
import pandas as pd
from sklearn.model_selection import train_test_split
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM
# Fetch the MovieLens-100K sample CSV from the NextRec GitHub repository.
df = pd.read_csv(
    "https://raw.githubusercontent.com/zerolovesea/NextRec/main/dataset/movielens_100k.csv"
)

# "age" is the only column declared as a dense feature.
dense_features = [DenseFeature("age")]

# Every sparse column gets embedding_dim=16; vocab_size is the column's
# largest value plus one.
# NOTE(review): presumably the columns hold non-negative integer ids — confirm.
sparse_features = [
    SparseFeature(column, vocab_size=df[column].max() + 1, embedding_dim=16)
    for column in ("user_id", "item_id", "gender", "occupation")
]

# Hold out 20% of the rows for validation; the fixed seed makes the split
# repeatable.
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)
# Build a DeepFM model over the declared dense and sparse features, with
# "label" as the target and CPU as the device.
mlp_params = {"hidden_dims": [256, 128], "activation": "relu", "dropout": 0.2}
model = DeepFM(
    dense_features=dense_features,
    sparse_features=sparse_features,
    mlp_params=mlp_params,
    target="label",
    device="cpu",
    session_id="movielens_deepfm",  # manages experiment logs and checkpoints
)
# Select the optimizer (Adam with the given learning rate and weight decay)
# and the binary cross-entropy loss.
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
model.compile(
    optimizer="adam",
    optimizer_params=optimizer_params,
    loss="binary_crossentropy",
)
# Train for two epochs on the training split, passing the validation split
# and the requested metrics to fit().
eval_metrics = ["auc", "recall", "precision"]
model.fit(
    train_data=train_df,
    valid_data=valid_df,
    metrics=eval_metrics,
    epochs=2,
    batch_size=512,
    shuffle=True,
)