项目介绍
首先是是用MQ2008数据集 和RankNet 网络做的。
RankNet 介绍:
请参考:RankNet 模型 & Pairwise源码解析 - 知乎
数据展示:
MQ2008的数据格式可以参考 百度的一些解释
数据分成了五折交叉验证
主要的模型代码:
import torch
import torch.utils.data as data
import numpy as np
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
y_train = []
x_train = []
query_id = []
array_train_x1 = []
array_train_x0 = []
def extract_features(toks):
# 获取features
features = []
for tok in toks:
features.append(float(tok.s