📊
统一特征抽象
支持 Dense、Sparse、Sequence 三类特征，并统一编码、转换与输入接口
pip install nextrec

import pandas as pd
from sklearn.model_selection import train_test_split
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM
# Load the MovieLens-100K sample hosted in the NextRec GitHub repository
# (requires network access at run time).
df = pd.read_csv("https://raw.githubusercontent.com/zerolovesea/NextRec/main/dataset/movielens_100k.csv")

# Dense (numeric) input column; this script applies no scaling to it.
dense_features = [DenseFeature("age")]

# Categorical columns, each mapped to a 16-dimensional embedding.
# vocab_size = max id + 1 assumes every column already holds non-negative
# integer ids (i.e. it is label-encoded) -- TODO confirm for "gender" and
# "occupation", which are often stored as strings in raw MovieLens data.
# int(...) turns the NumPy scalar returned by Series.max() into a plain Python
# int: NumPy integer scalars fail isinstance(x, int) checks and are rejected by
# json.dumps, so this keeps the feature config safe to validate/serialize.
sparse_embedding_dim = 16
sparse_columns = ["user_id", "item_id", "gender", "occupation"]
sparse_features = [
    SparseFeature(
        column,
        vocab_size=int(df[column].max()) + 1,
        embedding_dim=sparse_embedding_dim,
    )
    for column in sparse_columns
]

# Hold out 20% of the rows for validation; the fixed seed makes the split reproducible.
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)
# Hyperparameters for the DeepFM MLP tower, pulled out so they are easy to tweak.
deepfm_mlp_params = {"hidden_dims": [256, 128], "activation": "relu", "dropout": 0.2}

# Build a DeepFM ranking model over the dense and sparse features defined above,
# predicting the "label" column on CPU.
model = DeepFM(
    dense_features=dense_features,
    sparse_features=sparse_features,
    mlp_params=deepfm_mlp_params,
    target="label",
    device="cpu",
    session_id="movielens_deepfm",  # manages experiment logs and checkpoints
)

# Adam optimizer (lr=1e-3, weight_decay=1e-5) with binary cross-entropy loss.
model.compile(
    optimizer="adam",
    optimizer_params=dict(lr=1e-3, weight_decay=1e-5),
    loss="binary_crossentropy",
)

# Two shuffled epochs with batches of 512, tracking AUC, recall and precision
# and passing the held-out split as validation data.
fit_kwargs = {
    "train_data": train_df,
    "valid_data": valid_df,
    "metrics": ["auc", "recall", "precision"],
    "epochs": 2,
    "batch_size": 512,
    "shuffle": True,
}
model.fit(**fit_kwargs)