失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > 使用图神经网络预测药物-药物相互作用

使用图神经网络预测药物-药物相互作用

时间:2023-04-24 21:56:50

相关推荐

使用图神经网络预测药物-药物相互作用

使用图神经网络预测药物-药物相互作用

了解药物如何相互作用是医学研究和实践的关键问题。图形机器学习(GraphML)领域可用于以高可信度回答有关药物 - 药物相互作用的问题。本次学习通过GraphSAGE图卷积来预测DDI(药物-药物相互作用)。

使用图神经网络预测药物-药物相互作用

使用图神经网络预测药物-药物相互作用1.数据2.定义词汇意义3.学习目标4.加载数据集5.消息传递概述6.数据集拆分7.模型构建GraphSAGE构建链接预测器 8.训练

1.数据

数据来源于Open Graph Benchmark (OGB) 是用于图形机器学习任务的开源基准数据集的集合。我们在本文中的重点是ogbl-ddi数据集,如上所述,它由单个药物 - 药物相互作用(DDI)网络组成。

我们在数学上将 DDI 图定义为 G = (V, E),其中 V 是节点集,E 是边的集合。图中的每个节点 v ∈ V 代表 FDA 批准或实验药物。两个节点u和v之间存在边缘(u,v)表明两种药物相互作用,使得同时服用两种药物的效果与药物彼此独立作用的预期效果有很大不同[1]。例如,靶向相同蛋白质的两种药物可能具有显着的相互作用。

2.定义词汇意义

首先,一些词汇,‘正边’是数据集中存在的边。每个正边代表由边缘端点表示的两种药物之间的已知显着相互作用。

如果两种药物不相互作用(即,它们一起服用与单独服用时具有相同的效果),则图表中不会存在边缘。这些“单独的节点”被称为负边;换句话说,就是图中不存在的边。

3.学习目标

本次学习的目标是开发一个图形机器学习模型来解决链接预测任务:给定两个药物作为输入,我们希望预测两种药物是否相互作用,即图中的这两个节点之间是否应该存在边缘。这应该允许我们通过将缺失的边缘理解为正边或负边来完成数据集。

4.加载数据集

按照 OGB 网站的示例,我们可以将 DDI 数据集加载到 PyTorch Geometric (PyG) 中:

from ogb.linkproppred import PygLinkPropPredDatasetdataset_name = 'ogbl-ddi'dataset = PygLinkPropPredDataset(name=dataset_name)print(f'The {dataset_name} dataset has {len(dataset)} graph(s).')ddi_G = dataset[0]print(f'DDI 图: {ddi_G}')print(f'节点数量 |V|: {ddi_G.num_nodes}')print(f'边的数量 |E|: {ddi_G.num_edges}')print(f'无向图?: {ddi_G.is_undirected()}')print(f'节点平均度: {ddi_G.num_edges / ddi_G.num_nodes:.2f}')print(f'节点特征: {ddi_G.num_node_features}')print(f'有孤立点?: {ddi_G.has_isolated_nodes()}')print(f'有自循环?: {ddi_G.has_self_loops()}')

输出Using backend: pytorchDownloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zipDownloaded 0.04 GB: 100%|██████████████████████████████████████████████████████████████| 46/46 [00:31<00:00, 1.45it/s]Extracting dataset\ddi.zipProcessing...Loading necessary files...This might take a while.Processing graphs...100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 58.81it/s]Converting graphs into PyG objects...100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]Saving...The ogbl-ddi dataset has 1 graph(s).Done!DDI 图: Data(num_nodes=4267, edge_index=[2, 2135822])节点数量 |V|: 4267边的数量 |E|: 2135822无向图?: True节点平均度: 500.54节点特征: 0有孤立点?: False有自循环?: False

注意:DDI数据集中,图没有节点特征,后续需要进行加工处理。

5.消息传递概述

每个 GNN“层”由在每个节点与其邻居之间传递的一轮此消息、每个节点接收的邻居消息的聚合以及使用聚合消息计算更新的嵌入来定义的。

消息传递是节点能够合并来自其局部邻域结构的信息以确定其自身嵌入的机制。它由生成消息的每个节点组成,该消息在回合中沿该节点的传出边缘传递给其他节点。

消息传递的PyG实现,以及其他知识可以看我这一篇PyG消息传递

6.数据集拆分

OGB 为我们提供了数据集拆分。正边是图中存在的边:集合 {(u,v)∈E} ,其中 u,v∈V ,负边是图中不存在的边:集合 {(u,v)∉E} 。 我们需要正边和负边来训练我们的链接预测模型。 数据集的拆分详情可以查看ogb官网,

split_edges = dataset.get_edge_split()train_edges, valid_edges, test_edges = split_edges['train'], split_edges['valid'], split_edges['test']print(train_edges)print(f'训练集正边的数量: {train_edges["edge"].shape[0]}')print(f'验证集正边的数量: {valid_edges["edge"].shape[0]}')print(f'验证集负边的数量: {valid_edges["edge_neg"].shape[0]}')print(f'测试集正边的数量: {test_edges["edge"].shape[0]}')print(f'测试集负边的数量: {valid_edges["edge_neg"].shape[0]}')

输出{'edge': tensor([[4039, 2424],[4039, 225],[4039, 3901],...,[ 647, 708],[ 708, 338],[ 835, 3554]])}训练集正边的数量: 1067911验证集正边的数量: 133489验证集负边的数量: 101882测试集正边的数量: 133489测试集负边的数量: 101882

ddi_graph.edge_index 的形状为 [2, 2 * E] 因为我们在 GNN 中使用 edge_index,这需要从 u 向 v 和 v 向 u 发送信息(因为我们将在下面看到)。药物1对药物2有作用,相反药物2对药物1也有作用,作用是相互的。

7.模型构建

模型有两个部分: 1)图神经网络生成节点嵌入 2) 输出链接预测概率的深度神经网络

GraphSAGE构建

import torchimport torch_geometricimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.loader import DataLoaderfrom torch_geometric.nn import SAGEConvfrom torch_geometric.utils import negative_samplingfrom tqdm import trangeclass GraphSAGE(torch.nn.Module):"""使用 GraphSAGE 架构构建的图神经网络。"""def __init__(self, conv, in_channels, hidden_channels, out_channels, num_layers, dropout):'''in_channels:初始节点嵌入的维度。由于药物没有节点特征,我们将随机初始化这些向量。hidden_channels:中间节点嵌入的维度。隐藏层的维度。out_channels:输出节点嵌入的维度。num_layers:我们的 GNN 中的层数K。这是应用 GraphSAGE 运算符的次数。dropout:Dropout 应用于权重矩阵 W1 和 W2。'''super(GraphSAGE, self).__init__()self.convs = torch.nn.ModuleList()assert (num_layers >= 2), 'Have at least 2 layers'##至少两层卷积# 在每一个layer中增加conv,上一层与下一层的维度必须一致# 我们还应用了归一化,之后输出节点嵌入。每个卷积层都是 L2 归一化的。self.convs.append(conv(in_channels, hidden_channels, normalize=True))for l in range(num_layers - 2):self.convs.append(conv(hidden_channels, hidden_channels, normalize=True))self.convs.append(conv(hidden_channels, out_channels, normalize=True))self.num_layers = num_layersself.dropout = dropoutdef forward(self, x, edge_index, edge_attr):if edge_attr is not None: ## 如果有edge_attrreturn self.forward_with_edge_attr(x, edge_index, edge_attr)# x 是初始节点嵌入的矩阵,形状 [N, in_channels]for i in range(self.num_layers - 1):# 第 i 层进行消息传递和聚合x = self.convs[i](x, edge_index)# x 的形状为 [N, hidden_channels]# 通过非线性激活函数relux = F.relu(x)x = F.dropout(x, p=self.dropout, training=self.training)# 生成最终嵌入, x 的形状为 [N, out_channels]x = self.convs[self.num_layers - 1](x, edge_index)return xdef forward_with_edge_attr(self, x, edge_index, edge_attr):# x 是初始节点嵌入的矩阵,形状 [N, in_channels]for i in range(self.num_layers - 1):# 第 i 层进行消息传递和聚合x = self.convs[i](x, edge_index, edge_attr)# x 的形状为 [N, hidden_channels]# 通过非线性激活函数relux = F.relu(x)x = F.dropout(x, p=self.dropout,training=self.training)# 生成最终嵌入, x 的形状为 [N, out_channels]x = self.convs[self.num_layers - 1](x, edge_index, edge_attr)return x

#设置参数graphsage_in_channels = 256 graphsage_hidden_channels = 256 graphsage_out_channels = 256 graphsage_num_layers = 2 dropout = 0.5 ###注意,因为数据库ddi本身没有附带节点特征矩阵,所以我们要创立初始嵌入。torch.nn.Embeddinginitial_node_embeddings = torch.nn.Embedding(ddi_graph.num_nodes, graphsage_in_channels).to(device)##图节点特征向量形状为[N,in_channels]initial_node_embeddings

输出Embedding(4267, 256)

## 实例化模型GraphSAGEgraphsage_model = GraphSAGE(SAGEConv, graphsage_in_channels, graphsage_hidden_channels,graphsage_out_channels,graphsage_num_layers, dropout).to(device)

链接预测器

link_predictor_in_channels = graphsage_out_channelslink_predictor_hidden_channels = link_predictor_in_channelsclass LinkPredictor(torch.nn.Module):"""将两个输入转换为单个输出的通用网络。"""def __init__(self, in_channels, hidden_channels, dropout, out_channels=1,concat=lambda x, y: x * y):super(LinkPredictor, self).__init__()self.model = nn.Sequential(nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Dropout(p=dropout), nn.Linear(hidden_channels, out_channels), nn.Sigmoid())self.concat = concatdef forward(self, u, v):x = self.concat(u, v)return self.model(x)link_predictor = LinkPredictor(in_channels=link_predictor_in_channels, hidden_channels=link_predictor_hidden_channels, dropout=dropout).to(device)

8.训练

##训练我们的完整模型(GraphSAGE + LinkPredictor)def train(graphsage_model, link_predictor, initial_node_embeddings, edge_index, pos_train_edges, optimizer, batch_size, edge_attr=None):total_loss, total_examples = 0, 0# 设置我们的模型进行训练graphsage_model.train()link_predictor.train()# 迭代成批的训练边(“正边”)# (最后一次迭代的边数可能比 batch_size 少)for pos_samples in DataLoader(pos_train_edges, batch_size, shuffle=True):optimizer.zero_grad()# 运行 GraphSAGE 前向传递node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)#对由 attr:'edge_index'给出的图的随机对负边进行采样。# neg_samples 是一个尺寸为 [2, batch_size] 的张量neg_samples = negative_sampling(edge_index, num_nodes=initial_node_embeddings.size(0),num_neg_samples=len(pos_samples),method='dense')# 在正边嵌入上运行链接预测器前向传递pos_preds = link_predictor(node_embeddings[pos_samples[:, 0]], node_embeddings[pos_samples[:, 1]])# 在负边嵌入上运行链接预测器前向传递neg_preds = link_predictor(node_embeddings[neg_samples[0]], node_embeddings[neg_samples[1]])preds = torch.concat((pos_preds, neg_preds))labels = torch.concat((torch.ones_like(pos_preds), torch.zeros_like(neg_preds)))loss = F.binary_cross_entropy(preds, labels)loss.backward()optimizer.step()num_examples = len(pos_preds)total_loss += loss.item() * num_examplestotal_examples += num_examplesreturn total_loss / total_examples

##参数lr = 0.005 batch_size = 65536 epochs = 2 eval_steps = 5 optimizer = torch.optim.Adam(list(graphsage_model.parameters()) + list(link_predictor.parameters()),lr=lr)

我们根据 OGB 数据集中提供的验证和测试正负边缘来评估我们的模型:

pos_valid_edges = valid_edges['edge'].to(device)neg_valid_edges = valid_edges['edge_neg'].to(device)pos_test_edges = test_edges['edge'].to(device)neg_test_edges = test_edges['edge_neg'].to(device)from ogb.linkproppred import Evaluatorevaluator = Evaluator(name = dataset_name)

@torch.no_grad()def test(graphsage_model, link_predictor, initial_node_embeddings, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator, edge_attr=None):graphsage_model.eval()link_predictor.eval()final_node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)pos_valid_preds = []for pos_samples in DataLoader(pos_valid_edges, batch_size):pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], final_node_embeddings[pos_samples[:, 1]])pos_valid_preds.append(pos_preds.squeeze())pos_valid_pred = torch.cat(pos_valid_preds, dim=0)neg_valid_preds = []for neg_samples in DataLoader(neg_valid_edges, batch_size):neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], final_node_embeddings[neg_samples[:, 1]])neg_valid_preds.append(neg_preds.squeeze())neg_valid_pred = torch.cat(neg_valid_preds, dim=0)pos_test_preds = []for pos_samples in DataLoader(pos_test_edges, batch_size):pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], final_node_embeddings[pos_samples[:, 1]])pos_test_preds.append(pos_preds.squeeze())pos_test_pred = torch.cat(pos_test_preds, dim=0)neg_test_preds = []for neg_samples in DataLoader(neg_test_edges, batch_size):neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], final_node_embeddings[neg_samples[:, 1]])neg_test_preds.append(neg_preds.squeeze())neg_test_pred = torch.cat(neg_test_preds, dim=0)# Calculate Hits@20evaluator.K = 20valid_hits = evaluator.eval({'y_pred_pos': pos_valid_pred, 'y_pred_neg': neg_valid_pred})test_hits = evaluator.eval({'y_pred_pos': pos_test_pred, 'y_pred_neg': neg_test_pred})return valid_hits, test_hits

import matplotlib.pyplot as pltepochs_bar = trange(1, epochs + 1, desc='Loss n/a')edge_index = ddi_graph.edge_index.to(device)pos_train_edges = train_edges['edge'].to(device)losses = []valid_hits_list = []test_hits_list = []for epoch in epochs_bar:loss = train(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_train_edges, optimizer, batch_size)losses.append(loss)epochs_bar.set_description(f'Loss {loss:0.4f}')if epoch % eval_steps == 0:valid_hits, test_hits = test(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator)print()print(f'Epoch: {epoch}, Validation Hits@20: {valid_hits["hits@20"]:0.4f}, Test Hits@20: {test_hits["hits@20"]:0.4f}')valid_hits_list.append(valid_hits['hits@20'])test_hits_list.append(test_hits['hits@20'])else:valid_hits_list.append(valid_hits_list[-1] if valid_hits_list else 0)test_hits_list.append(test_hits_list[-1] if test_hits_list else 0)plt.title(dataset.name + ": GraphSAGE")plt.xlabel("Epoch")plt.plot(losses, label="Training loss")plt.plot(valid_hits_list, label="Validation Hits@20")plt.plot(test_hits_list, label="Test Hits@20")plt.legend()plt.show()

输出Loss 0.4814: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:27<00:00, 13.65s/it]

我自己电脑太拉了就跑了2个epoch,看看能不能跑通代码试一下。

白嫖gpu跑了一下50epoch

如果觉得《使用图神经网络预测药物-药物相互作用》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。