Deep Contextual Clinical Prediction with Reverse Distillation

Abstract

Healthcare providers are increasingly using machine learning to predict patient outcomes in order to make meaningful interventions. However, despite innovations in this area, deep learning models often struggle to match the performance of shallow linear models in predicting these outcomes, making it difficult to leverage such techniques in practice. In this work, motivated by the task of clinical prediction from insurance claims, we present a new technique called Reverse Distillation, which pretrains deep models by using high-performing linear models for initialization. We make use of the longitudinal structure of insurance claims datasets to develop Self Attention with Reverse Distillation, or SARD, an architecture that combines contextual embedding, temporal embedding, and self-attention mechanisms and, most critically, is trained via reverse distillation. SARD outperforms state-of-the-art methods on multiple clinical prediction outcomes, with ablation studies revealing that reverse distillation is a primary driver of these improvements.
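
The core idea of reverse distillation inverts the usual teacher-student setup: rather than distilling a large model into a small one, a deep model is first trained to mimic the predictions of a strong linear model, and only afterwards fine-tuned on the true labels. Below is a minimal PyTorch sketch of that two-stage loop, assuming a binary outcome; `deep_model`, `linear_model`, the data loader, and all hyperparameters are illustrative placeholders, not details taken from the paper.

```python
import torch
import torch.nn as nn

def reverse_distill(deep_model, linear_model, loader, epochs=5, lr=1e-3):
    """Stage 1: pretrain the deep model to match the linear teacher's logits."""
    opt = torch.optim.Adam(deep_model.parameters(), lr=lr)
    mse = nn.MSELoss()
    for _ in range(epochs):
        for x, _ in loader:  # outcome labels are unused during pretraining
            with torch.no_grad():
                teacher_logits = linear_model(x)  # frozen linear model
            loss = mse(deep_model(x), teacher_logits)
            opt.zero_grad()
            loss.backward()
            opt.step()

def fine_tune(deep_model, loader, epochs=5, lr=1e-4):
    """Stage 2: standard supervised fine-tuning on the true outcomes."""
    opt = torch.optim.Adam(deep_model.parameters(), lr=lr)
    bce = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        for x, y in loader:
            loss = bce(deep_model(x).squeeze(-1), y.float())
            opt.zero_grad()
            loss.backward()
            opt.step()
```

In this sketch the pretraining stage effectively initializes the deep model near the linear model's decision function, so fine-tuning starts from a strong baseline rather than a random initialization.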

Publication
Proceedings of the Thirty-Fifth AAAI Conference on Artificial Intelligence
Rohan Kodialam
Master’s student

Citadel

Rebecca Peyser Boiarsky
PhD student

Rebecca’s research interests include developing methods to learn disease subtypes and disease progression models for precision medicine applications. She is particularly interested in leveraging machine learning algorithms together with bioinformatics to better understand disease.

Justin Lim
Master’s student

Justin works on interpretable and robust algorithms to learn treatment policies and understand their variation in practice.

David Sontag
Associate Professor of EECS

David's research focuses on advancing machine learning and artificial intelligence, and on using these to transform health care.
