表格分类的全新方法:我们引入了 TabPFN,一种新的表格数据分类方法,耗时不到 1 秒,却能提供 SOTA(最先进)的性能(与耗时数小时的最佳 AutoML 流水线相比具有竞争力)。
不过,目前它的规模有限:它只能处理多达 1000 个训练样本、100 个特征和 10 个类别的问题。当所有特征都是数值型且没有缺失值时,它的效果最佳(但我们相信,当我们专注于这些情况时,我们也会改善它们在其他情况下的性能)。
TabPFN 与以前的机器学习方法截然不同。它是一种元学习算法,可证明地近似贝叶斯推断,其先验基于因果关系和简约性原理。从定性上看,它产生的预测也非常直观,具有非常平滑的不确定性估计:

TabPFN 恰好是一个 Transformer,但这不是通常的“树与网络”之战b a t t l e。给定一个新的数据集,TabPFN 不使用昂贵/不可靠的基于梯度的训练,面对小型数据集时不会过拟合,并且不需要任何调优。相反,它执行固定网络的一次前向传播:您将训练数据作为集合值输入馈入,以及 $latex x_{test}$;然后网络输出 $latex y_{test}$ 的概率。
TabPFN 经过预训练以近似贝叶斯推断,正如本图所示。
在离线预训练阶段,我们通过从我们对数据集可能样子的先验中采样,生成了数百万个合成数据集,并训练 TabPFN 以对每个数据集中的保留点进行一次前向传播预测。TabPFN 的先验基于结构化因果模型,并通过采样此类模型生成数据,倾向于简约性。对该先验进行贝叶斯推断,将预测整合到结构化因果模型的空间中,并根据数据给定的似然和先验中的概率进行加权——这捕获了数据不同因果解释背后的潜在不确定性。对于新的数据集,一次前向传播即可近似我们先验的贝叶斯推断。
因此,这个 Transformer 学会了在单次前向传播中充当分类算法。归根结底,分类算法对数据进行计算,TabPFN 的前向传播也是如此。因此,TabPFN 实现了一种算法,该算法通过基于梯度的方式进行元学习,以最小化其预测误差。既然这适用于数百万个训练数据集,那么它也适用于测试数据集也就不足为奇了 🙂
想象一下您可以使用 TabPFN 几乎瞬时的预测做些什么。实时机器学习,仅需单个神经网络的一次前向传播,这种操作在不同平台之间具有极强的可移植性,部署起来也非常简单。也可以在智能手机、传感器等设备上使用。加油,#GreenAutoML! (诚然,还有许多其他快速分类器,例如随机森林;TabPFN 提供了与它们相同的速度,同时性能可与当今运行长达一小时的最佳 AutoML 方法相媲美。)
对 Léo Grinsztajn、Edouard Oyallon 和 Gaël Varoquaux 的优秀论文(https://hal.archives-ouvertes.fr/hal-03723551)进行的定性分析表明,TabPFN 在许多方面仍然表现得类似于其他神经网络方法,但性能非常强劲。这可能指向未来有趣的工作,即纳入更多传统上归因于基于树的方法的特征的先验。
TabPFNs 的这项工作为许多人带来了益处:需要快速方法的经典数据科学家、深度学习者、贝叶斯主义者、元学习者等。这可能会形成一个令人兴奋的新社区。
- 经典数据科学家:梦想成真——能够处理小型数据集而无需担心过拟合。到目前为止,小型数据集的最佳答案是随机森林。在我们的实验中,TabPFN 在 179 个数据集上明显优于这些方法(即使对于我们未重点关注的分类数据集也是如此)。另一个优点是性能强劲且执行快速,无需任何调优。具有这些需求愿望的另一种经典方法是梯度提升决策树(XGBoost、CatBoost、LightGBM 等),我们的实验表明 TabPFN 在数值数据集上速度更快且表现出统计学上的显著优势。在分类数据集上,TabPFN 和 XGBoost 表现相当。TabPFN 还具有与众不同的归纳偏置,使其与其他方法不太相关;因此,它可以非常有效地与其他技术进行集成。
局限性?我们对 1000 个训练样本、100 个特征和 10 个类别的限制。并且存在很大的差异;您会发现某些数据集上任何分类方法(包括 TabPFN)都表现不佳。您还会发现 SVM、RF 和 GB 优于 TabPFN 的数据集。我们统计上的显著改进仅适用于跨数据集,而非每个单独的数据集。
- 深度学习者:这是表格数据上许多关于新颖架构、正则化器等最新工作的自然发展。与这些工作不同,我们使用了深度学习中的最新方法,特别是上下文学习,并且我们没有这些方法的潜在缺点(例如为新数据集进行昂贵的神经网络训练或对小型数据集的过拟合)。
- 贝叶斯主义者:另一个梦想成真——像一次前向传播一样快速地计算后验预测分布的近似值。与标准贝叶斯深度学习相比,这也适用于复杂的先验、各种架构、初始权重等。
- 元学习者:最后,一个元学习能够产生最先进性能的应用。虽然整个社区仍然使用在 ImageNet/JFT-300 上预训练的网络而不是 MAML,并且我们仍然都使用 Adam 而不是学习型优化器,但我们令人信服地展示了元学习在实践中的力量。元学习的 TabPFN 似乎比数十年来手动创建的算法更适合处理小型表格数据,因此可以直接应用于实践。
为什么我们目前仅限于小数据?对 1000 个训练数据点的限制是由于标准 Transformer 内存和计算要求与输入长度呈二次关系。目前有大量工作旨在克服这一限制,这也将适用于我们。对 100 个特征和 10 个类别的限制主要是为了保持训练时间合理(1 台带有 8 个 GTX2080 的机器运行 20 小时)。到目前为止,我们一直专注于没有缺失值的数值数据,这就是为什么在具有这些特征的数据集上性能更好的原因。
我们能克服这些限制吗?嗯,今年早些时候我们从 30 个数据点上的平衡二分类(https://arxiv.org/abs/2112.10510)扩展到现在的 1000 个数据点和不平衡数据。我们将继续扩展,但随着我们利用低垂的果实,速度会放缓。
我们预计我们的激进主张会受到最初的质疑。这对科学是健康的!请查看我们的论文了解详细信息,并在可能的情况下找出漏洞。我们开源了所有代码,包括一个 sklearn 接口和一个演示其用法的 Colab 笔记本。我们还有 2 个演示: 一个用于体验 TabPFNs 预测(https://hugging-face.cn/spaces/TabPFN/TabPFNPrediction),另一个用于检查新数据集上的交叉验证 ROC AUC 分数(https://hugging-face.cn/spaces/TabPFN/TabPFNEvaluation)。
这还不是故事的结局。这只是开始。未来有数十种工作可能性,我们很乐意加入合作。如果您认识希望尝试小型表格数据的廉价 SOTA 方法的数据科学家,请广泛分享。重复免责声明:到目前为止,这仅限于不超过 1000 个数据点、100 个特征和 10 个类别的问题,并且在没有缺失特征的数值数据集上表现更好。
此外,如果您想加入我们来扩展这项工作,我们正在 ELLIS 和我们的 ERC 合并拨款项目“深度学习 2.0”的背景下招聘杰出人才,担任博士、博士后和研究工程师职位(请参阅https://automl.org.cn/deep-learning-2-0-extending-the-power-of-deep-learning-to-the-meta-level/)。请在此处申请:https://ml.informatik.uni-freiburg.de/positions/
完整论文:https://arxiv.org/abs/2207.01848
带有 scikit-learn 接口的 Colab 笔记本:https://colab.research.google.com/drive/194mCs6SEPEW6C0rcP7xWzcEtt1RBc8jJ?usp=sharing
我们用于在表格上进行预测和查看表格ROC AUC 的演示(两者都在弱 CPU 上运行,Colab 使用 GPU 会更快)。
