本文是生命科学AI模型系列教程的一部分,源于阿斯利康(AstraZeneca)与AMD的合作文章,该系列探讨了如何在AMD MI300X上运行与药物发现相关的AI工作负载。本系列的第一篇文章重点介绍了REINVENT4,这是一种用于生成和优化候选分子的分子设计工具。本文特别关注SemlaFlow,一种具有潜在注意力和等变流匹配的高效3D分子生成模型。
基于模拟的药物发现流程是制药行业中用于识别和优化新药物候选物的计算方法。这些流程利用计算机模拟来建模生物系统,并预测潜在药物分子如何与其靶标(如蛋白质或核酸)相互作用。
3D分子生成是一种用于创建分子三维表示的计算技术,在药物发现中扮演着关键角色。在基于模拟的药物发现流程中,3D分子生成与分子建模和虚拟筛选相结合。通过生成3D结构,研究人员可以执行对接模拟,以预测这些分子与靶标蛋白的结合程度。这一步骤对于识别潜在药物候选物至关重要,因为它提供了关于分子结合亲和力和特异性的见解。以三维方式可视化和模拟这些相互作用的能力提高了药物发现过程的准确性和效率。
以往的3D分子生成方法通常面临重大限制,包括采样时间非常慢以及生成的分子化学有效性差,阻碍了它们在药物发现工作流程中的实际应用。
SemlaFlow通过提供最先进的3D分子生成解决方案来应对这些挑战,与现有方法相比,它在采样时间上提供了两个数量级的加速(相当于超过100倍的改进),仅需少至20个采样步骤。这种效率是通过其新颖且可扩展的E(3)-等变Semla架构及其通过等变流匹配进行的训练实现的。SemlaFlow的独特之处在于它能够生成原子类型、坐标、键类型和形式电荷的联合分布,从而无需在生成后推断分布的部分即可提供全面的分子设计。
在本文中,我们展示了如何通过最少的代码更改开始使用,并概述了在AMD硬件上优化SemlaFlow的关键步骤。所使用的环境是具有8个MI300X的TensorWave节点。我们在此的主要关注点是训练。虽然我们使用预测和评估工具来测试训练工具生成的模型,但我们并不专注于优化这些工具。我们的优化工作仅致力于提高3D分子生成模型训练的效率。
SemlaFlow代码
SemlaFlow是一个基于Linux的应用程序,使用Python开发,利用PyTorch实现神经网络。由于PyTorch与ROCm兼容,在AMD GPU上运行SemlaFlow应该很直接。
原始的SemlaFlow仓库包含四个主要脚本:
preprocess- 用于将较大数据集预处理为模型训练使用的内部表示train- 在预处理数据上训练MolFlow模型evaluate- 评估训练好的模型并打印结果predict- 运行训练模型的采样并保存生成的分子
安装
尽管原始的SemlaFlow仓库提供了在Nvidia GPU上运行的说明,但在AMD GPU上运行SemlaFlow相对简单。例如,mamba环境.yml文件指定了CUDA依赖项:pytorch-cuda,可以替换为AMD硬件的等效项rocm-pytorch。其余依赖项可以保持不变,因为它们是与计算无关的包。
要在Kubernetes集群上操作,需要一个可docker化的配方。本博客的其余部分假设环境支持Docker。让我们继续。
Docker化SemlaFlow
使用ROCm和PyTorch安装的最简单方法是使用已安装这些包的基础镜像。Docker Hub上有几个不同版本。所使用的基础镜像是我们在进行实验时最新的版本rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.6.0
使用较早版本的ROCm基础镜像会导致SemlaFlow出现一些问题。升级到带有python3.12的最新镜像解决了该问题。
因此,dockerfile基本上使用最新的ROCm镜像作为基础镜像,并使用conda和pip安装mamba环境yml文件中列出的包。
最后一步是告诉镜像在创建容器时运行什么。我们可以提供如下入口点脚本:
!/bin/bash
SCRIPT="$1"
OUTPUT_FILE="$2"
shift
shift
echo ${OUTPUT_FILE}
检查SCRIPT是否为允许的值之一
if [[ "$SCRIPT" != "preprocess" && "$SCRIPT" != "train" && "$SCRIPT" != "evaluate" && "$SCRIPT" != "predict" ]]; then
echo "Error: SCRIPT must be one of 'preprocess', 'train', 'evaluate', or 'predict'."
exit 1
fi
python -m semlaflow."$SCRIPT" "$@" > /output/${OUTPUT_FILE}
if [[ "$SCRIPT" == "train" ]]; then
将模型检查点复制到输出目录
cp -r lightninglogs/version* /output/
fi
它允许多个参数传递,应该运行哪个脚本、输出文件以及其他运行所选脚本所需的参数。运行脚本所需的参数可以在脚本的参数解析器中找到。
示例命令:
bash entrypoint.sh rocm-semlaflow


