10 changes: 5 additions & 5 deletions app/projects/faformer/page.mdx
@@ -21,7 +21,7 @@ Understanding and predicting how protein forms a complex with nucleic acid/prote

Motivated by this, we propose a **contact map prediction-based aptamer screening paradigm**. Specifically, as presented in Figure 2(a), our model is trained to identify specific contact pairs between residues and nucleotides when forming a complex. The maximum contact probability across all pairs is then interpreted as the binding affinity, which is subsequently used for aptamer screening.

![Figure 2: (a) The pipeline of contact map prediction between protein and nucleic acid, and applying the predicted results for screening in an unsupervised manner. (b) Comparison between Transformer with vanilla frame averaging framework and FAFormer, where the blue cells indicate FA-related modules. |scale=0.5](./fig/overview.png)
![Figure 2: (a) The pipeline of contact map prediction between protein and nucleic acid, and applying the predicted results for screening in an unsupervised manner. (b) Comparison between Transformer with vanilla frame averaging framework and FAFormer, where the blue cells indicate FA-related modules. |scale=0.8](./fig/overview.png)

Learning E(3)-equivariant transformations is key to modeling protein and nucleic acid 3D structures. In this paper, we propose **FAFormer**, an equivariant Transformer architecture that integrates FA as a geometric module within each layer, as shown in Figure 2(b). As a geometric component, FA offers the flexibility to integrate geometric information into node representations while preserving the spatial semantics of coordinates, without major modifications to the architecture. FAFormer opens new possibilities for designing equivariant architectures in this domain.

@@ -33,7 +33,7 @@ Learning E(3) equivariant transformation is the key factor to modeling the prote
Frame averaging (FA) is an encoder-agnostic framework that can make a given encoder equivariant to the Euclidean symmetry group. Specifically, FA proposes to model
the coordinates in eight different frames extracted by PCA, achieving equivariance by averaging the encoded representations, as presented in Figure 3.

![Figure 3: Frame Averaging.|scale=0.5](./fig/fa.png)
![Figure 3: Frame Averaging.|scale=0.3](./fig/fa.png)

You can think of FA as a model "wrapper": the model architecture does not need to be modified but separately processes the 8 framed inputs. We use $f_{\mathcal{F}}(\mathbf{X})=\{\mathbf{X}^{(g)}\}_{\mathcal{F}}$ to denote the FA operation, where $\mathbf{X}^{(g)}$ is the input in the $g$-th frame, and $f_{\mathcal{F}^{-1}}(\{\mathbf{\hat{X}}^{(g)}\}_{\mathcal{F}})=\hat{\mathbf{X}}$ to denote the inverse mapping, which is an E(3)-equivariant operation. Note that $\hat{\mathbf{X}}^{(g)}$ is obtained from the encoder. The output becomes E(3)-invariant when the representations are simply averaged without the inverse mapping.
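
As a rough illustration, here is a minimal NumPy sketch of the FA wrapper (function names are ours, and the encoder is assumed to map $N\times 3$ coordinates to $N\times 3$ outputs):

```python
import itertools
import numpy as np

def make_frames(X):
    """Build the 8 PCA frames of a point cloud X (N x 3): principal axes
    of the centered coordinates under all 2^3 sign flips."""
    t = X.mean(axis=0)
    cov = (X - t).T @ (X - t)
    _, U = np.linalg.eigh(cov)  # columns are the principal axes
    return t, [U * np.array(s) for s in itertools.product([1.0, -1.0], repeat=3)]

def frame_average(X, encoder):
    """E(3)-equivariant FA wrapper: encode X in each frame, map back, average."""
    t, frames = make_frames(X)
    outs = []
    for F in frames:
        X_g = (X - t) @ F                # coordinates in the g-th frame
        outs.append(encoder(X_g) @ F.T)  # inverse mapping back to input space
    return np.mean(outs, axis=0) + t     # average over the 8 frames
```

Dropping the inverse mapping (the `@ F.T` and `+ t`) and averaging the encoder outputs directly yields the invariant variant mentioned above.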

Expand All @@ -46,7 +46,7 @@ where $\mathbf{W}_g\in\Bbb{R}^{3\times D}$. Note that the output of FA Linear mo

### Overall architecture of FAFormer

![Figure 4: Overview of FAFormer architecture. The input consists of the node features, coordinates, and edge representations.|scale=0.5](./fig/faformer.png)
![Figure 4: Overview of FAFormer architecture. The input consists of the node features, coordinates, and edge representations.|scale=0.8](./fig/faformer.png)

As shown in Figure 4(a), the input of FAFormer comprises the node features $\mathbf{Z}\in\Bbb{R}^{N\times D}$, coordinates $\mathbf{X}\in\Bbb{R}^{N\times 3}$, and edge representations $\mathbf{E}\in\Bbb{R}^{N\times K\times D}$, where $K$ is the number of nearest neighbors. Each core module is tightly integrated with FA, including

@@ -61,11 +61,11 @@ As shown in Figure 4(a), the input of FAFormer comprises the node features $\mat

This task aims to predict the exact contact pairs between a protein and a protein/nucleic acid, which amounts to binary classification over all pairs. The task is challenging due to the sparsity of the contact pairs. We compare FAFormer with six state-of-the-art models, and the results are presented in Figure 5.

![Figure 5: Contact Map Prediction.|scale=0.5](./fig/exp1.png)
![Figure 5: Contact Map Prediction.|scale=0.7](./fig/exp1.png)

### Unsupervised Aptamer Screening

This task aims to screen positive aptamers from a large pool of candidates for a given protein target. We quantify the binding affinity between an RNA and the protein target as the highest contact probability among all residue-nucleotide pairs. The models are first trained on the protein-RNA complex training set with the contact map prediction objective, and the aptamer candidates are then ranked by their highest contact probabilities. Top-10 precision, Top-50 precision, and PRAUC are used as the metrics.

![Figure 6: Aptamer screening.|scale=0.5](./fig/exp2.png)
![Figure 6: Aptamer screening.|scale=0.7](./fig/exp2.png)
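
A minimal sketch of this screening loop (assuming a trained model that returns one residue-by-nucleotide contact-probability matrix per candidate; the function name is ours):

```python
import numpy as np
from sklearn.metrics import average_precision_score

def screen_aptamers(contact_maps, labels, k=10):
    """Rank aptamer candidates by predicted affinity and score the ranking.

    contact_maps: list of (num_residues, num_nucleotides) probability arrays.
    labels: binary array with 1 for true binders.
    """
    affinity = np.array([m.max() for m in contact_maps])  # max contact prob.
    order = np.argsort(-affinity)                         # best candidates first
    top_k_precision = labels[order[:k]].mean()            # Top-k precision
    prauc = average_precision_score(labels, affinity)     # PRAUC
    return top_k_precision, prauc
```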

Binary file added app/projects/heart/fig/FM.png
Binary file added app/projects/heart/fig/arch.png
Binary file added app/projects/heart/fig/ehr.png
Binary file added app/projects/heart/fig/exp.png
Binary file added app/projects/heart/fig/objective.png
75 changes: 75 additions & 0 deletions app/projects/heart/page.mdx
@@ -0,0 +1,75 @@
import { Authors, Badges } from '@/components/utils'

# HEART: Learning Better Representation of EHR Data with a Heterogeneous Relation-Aware Transformer

<Authors
authors="Tinglin Huang, Yale University; Syed Asad Rizvi, Yale University; Rohan Krishna Thakur, Yale University; Vimig Socrates, Yale University; Meili Gupta, Yale University; David van Dijk, Yale University; R. Andrew Taylor, Yale University; Rex Ying, Yale University"
/>

<Badges
venue="Journal of Biomedical Informatics 159 (2024): 104741"
github="https://github.com/Graph-and-Geometric-Learning/HEART"
paper="https://www.sciencedirect.com/science/article/abs/pii/S153204642400159X"
/>


## Introduction
Electronic health records (EHRs) are tabular data that digitize the medical information of an encounter, such as demographics, diagnoses, medications, lab results, and procedures, as shown in Figure 1:

![Figure 1: Illustration of EHRs.|scale=0.3](./fig/ehr.png)

Much research focuses on distilling meaningful clinical information from cohorts with **foundation models**. Specifically, such models treat medical entities in EHRs as tokens and organize the entities included in an encounter as a sentence. These "sentences" can then be encoded by a Transformer, allowing the entities to be represented in an embedding space, as shown in Figure 2(a):

![Figure 2: Comparison between current foundation model and ours.|scale=0.8](./fig/FM.png)

However, we argue that the heterogeneous correlations between medical entities are critical for representation but have largely been overlooked. For example, understanding the relationship between "Antibiotics" (medication) and both "Fever" (diagnosis) and "Antibody Tests: Positive" (lab test) enables the model to recommend more clinically plausible drugs.

Motivated by this, we propose **HEART**, a Heterogeneous Relation-Aware Transformer for EHR data, which explicitly parameterizes pairwise representations between entities in a heterogeneity-aware way. Additionally, we introduce a multi-level attention mechanism to mitigate the computational cost associated with multiple visits, as demonstrated in Figure 2(b). Finally, two dedicated objectives are applied to enhance the model during pretraining.


## Method

### Heterogeneous Relation Embedding & Multi-level Attention Scheme

Given a patient, we flatten the corresponding historical visits into several sequences of entities:
$$
[[D_1, V_{1,1},\cdots,V_{1,N_1}],\cdots,[D_S, V_{S,1},\cdots,V_{S,N_S}]]
$$
where $S$ is the number of visits, $N_i$ is the number of entities in the $i$-th visit, and $D_i$ represents the demographic token for the patient at the $i$-th visit. A learnable embedding is assigned to each entity. Besides the entity embeddings, we explicitly encode a pairwise representation for each entity pair. Specifically, for an entity pair $(V_n,V_m)$ in the same visit, we calculate the pairwise embedding $\textbf{R}_{n\leftarrow m}$ as follows:
$$
\textbf{R}_{n}=\text{Linear}_{\tau(V_n)}(\textbf{V}_{n}), \textbf{R}_{m}=\text{Linear}_{\tau(V_m)}(\textbf{V}_{m});\\
\textbf{R}_{n\leftarrow m}=\text{Linear}(\textbf{R}_{n}||\textbf{R}_{m})
$$
where $\text{Linear}_{\tau(\cdot)}$ denotes a type-specific linear transformation. This encoding is applied to every pair of entities within the same visit.
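
A minimal PyTorch sketch of this pairwise encoding (the entity types and dimensions here are illustrative assumptions, not taken from the released code):

```python
import torch
import torch.nn as nn

class RelationEmbedding(nn.Module):
    """Pairwise embedding R_{n<-m} built from type-specific projections."""
    def __init__(self, dim, entity_types=("diagnosis", "medication", "lab")):
        super().__init__()
        self.type_proj = nn.ModuleDict({t: nn.Linear(dim, dim) for t in entity_types})
        self.pair_proj = nn.Linear(2 * dim, dim)

    def forward(self, v_n, v_m, type_n, type_m):
        r_n = self.type_proj[type_n](v_n)  # Linear_{tau(V_n)}(V_n)
        r_m = self.type_proj[type_m](v_m)  # Linear_{tau(V_m)}(V_m)
        return self.pair_proj(torch.cat([r_n, r_m], dim=-1))  # R_{n<-m}
```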

Computational cost is one of the biggest challenges in encoding these heterogeneous representations. To alleviate this, we implement a hierarchical encoding scheme that combines encounter-level and entity-level attention, as shown in Figure 3:

![Figure 3: Multi-level attention scheme of HEART.|scale=0.8](./fig/arch.png)

Specifically, for the entity-level context, we apply attention among the entities within the same visit:
$$
[\mathbf{D}',\mathbf{V}_1',\cdots,\mathbf{V}_N']=\text{Entity-Attn}([\mathbf{D},\mathbf{V}_1,\cdots,\mathbf{V}_N])
$$
In addition, the heterogeneous relations are introduced as a bias term that refines the attention map and as context for updating the entity embeddings. For the encounter-level context, we restrict attention to the demographic tokens across all historical encounters:
$$
[\mathbf{D}_1',\cdots,\mathbf{D}_S']=\text{Enc-Attn}([\mathbf{D}_1,\cdots,\mathbf{D}_S])
$$
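
A sketch of the two attention levels (with the relation bias omitted for brevity; module and variable names are ours):

```python
import torch
import torch.nn as nn

class MultiLevelAttention(nn.Module):
    """Entity-level attention within each visit, then encounter-level
    attention across the demographic tokens of all visits."""
    def __init__(self, dim, heads=4):
        super().__init__()
        self.entity_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.enc_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, visits):
        # visits: list of (1 + N_i, dim) tensors, demographic token first.
        updated = [self.entity_attn(v[None], v[None], v[None])[0][0] for v in visits]
        demos = torch.stack([h[0] for h in updated])[None]  # (1, S, dim)
        demos = self.enc_attn(demos, demos, demos)[0][0]    # refreshed D tokens
        return [torch.cat([demos[i:i + 1], h[1:]]) for i, h in enumerate(updated)]
```

In the full model, the relation embeddings $\textbf{R}_{n\leftarrow m}$ would additionally enter the entity-level attention as a bias on the attention logits.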


### Pretraining Objectives

Most previous approaches adopt masked token prediction (MTP) for pretraining, which replaces actual tokens with [MASK] and performs single-label classification at each masked position. However, MTP is position-dependent and thus unsuitable for EHRs, given the unordered nature of medical entities. In light of this, we adapt MTP into a missing entity prediction (MEP) task, which is position-agnostic and heterogeneity-aware. The main idea is to let the model perform multi-label classification based on one [MASK] for each entity type, as shown in Figure 4.

![Figure 4: Comparison between masked token prediction and missing entities prediction.|scale=0.7](./fig/objective.png)
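
A sketch of the MEP head (one [MASK] per entity type with multi-label BCE; the vocabulary sizes and names are placeholders):

```python
import torch
import torch.nn as nn

class MEPHead(nn.Module):
    """Multi-label prediction over each type's vocabulary from the hidden
    state of that type's single [MASK] token."""
    def __init__(self, dim, vocab_sizes):  # e.g. {"diagnosis": 5000, ...}
        super().__init__()
        self.heads = nn.ModuleDict({t: nn.Linear(dim, n) for t, n in vocab_sizes.items()})
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, mask_states, targets):
        # mask_states[t]: (dim,) hidden state of the type-t [MASK];
        # targets[t]: multi-hot float vector of the held-out type-t entities.
        return sum(self.loss(self.heads[t](h), targets[t]) for t, h in mask_states.items())
```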

Besides, we incorporate anomaly detection as an additional pretraining task to encourage the model to identify unrelated entities given a context and to learn more robust representations. Specifically, we replace some entities with random entities of the same type to synthesize anomalous data, and a binary classifier predicts whether each entity is an anomaly.
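
For instance, the corruption step could look like this sketch (the entity representation and per-type vocabularies are assumptions):

```python
import random

def corrupt_visit(entities, vocab_by_type, p=0.15):
    """Replace a fraction of entities with random same-type entities and
    return binary anomaly labels (1 = replaced)."""
    corrupted, labels = [], []
    for ent_type, ent_id in entities:  # entities: list of (type, id) pairs
        if random.random() < p:
            corrupted.append((ent_type, random.choice(vocab_by_type[ent_type])))
            labels.append(1)
        else:
            corrupted.append((ent_type, ent_id))
            labels.append(0)
    return corrupted, labels
```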


## Downstream tasks

We evaluate HEART across 5 downstream tasks on 2 EHR datasets:
* Dataset: [MIMIC-III](https://mimic.mit.edu/docs/iii/) and [eICU](https://eicu-crd.mit.edu/about/eicu/).
* Downstream tasks: death prediction, prolonged length of stay (PLOS) prediction, readmission prediction, and next diagnosis prediction within 6/12 months.

![Figure 5: Benchmarking.|scale=0.7](./fig/exp.png)
Binary file added app/projects/molgroup/fig/arch.png
Binary file added app/projects/molgroup/fig/bi-level.png
Binary file added app/projects/molgroup/fig/fingerprint.png
Binary file added app/projects/molgroup/fig/graph.png
Binary file added app/projects/molgroup/fig/grouping.png
Binary file added app/projects/molgroup/fig/results.png
Binary file added app/projects/molgroup/fig/study.png
Binary file added app/projects/molgroup/fig/study2.png
Binary file added app/projects/molgroup/fig/study3.png
103 changes: 103 additions & 0 deletions app/projects/molgroup/page.mdx
@@ -0,0 +1,103 @@
import { Authors, Badges } from '@/components/utils'

# Learning to Group Auxiliary Datasets for Molecule

<Authors
authors="Tinglin Huang, Yale University; Ziniu Hu, University of California, Los Angeles; Rex Ying, Yale University"
/>

<Badges
venue="NeurIPS 2023"
github="https://github.com/Graph-and-Geometric-Learning/MolGroup"
arxiv="https://arxiv.org/abs/2307.04052"
pdf="https://arxiv.org/pdf/2307.04052"
/>


## Introduction

Machine learning has become a powerful tool for molecular property prediction, such as estimating the toxicity of new clinical drugs and characterizing binding results for inhibitors. However, labeling molecules requires expensive real-world clinical trials and expert knowledge, resulting in limited dataset sizes. For example, ClinTox, which predicts the clinical toxicity of new drugs, includes only 1,000 labeled instances, and FreeSolv, which estimates the hydration free energy of molecules in water, contains only 400.

In this paper, we explore a strategy of collaborating with **auxiliary datasets**. Specifically, as shown in Figure 1, we augment the target dataset with auxiliary datasets that are collected from different sources and have different properties. Our assumption is that introducing out-of-distribution knowledge can help the model learn more robust and generalizable representations.

![Figure 1: Illustration of dataset grouping.|scale=0.8](./fig/grouping.png)

The key challenge is that more data does not always guarantee improvement: negative transfer can occur! We first conduct an empirical study that combines all possible pairs of datasets and evaluates the performance improvement on each target dataset. As shown in Figure 2, performance varies significantly across pairs, which demonstrates the uneven affinity among datasets.

![Figure 2: Relative improvement of using the combination of target dataset and auxiliary dataset over only using target dataset: (perf(a, b) − perf(a))/perf(a), where a is target dataset and b is auxiliary dataset.|scale=0.8](./fig/study.png)

In this paper, we focus on designing a dataset grouping method for molecules and propose MolGroup, a routing-based molecule grouping method. MolGroup calculates an affinity score for each auxiliary dataset based on graph structure and task information, and selects the auxiliary datasets with high affinity.

## Empirical Study

To gain a deeper understanding of dataset affinity, we analyze the relationship between datasets along two dimensions:
* Structural characteristics: **Fingerprint features**
* Associated predictive task: **Task embedding** extracted by Task2vec

![Figure 3: Structural and task distribution extraction.|scale=0.7](./fig/fingerprint.png)

For each pair of datasets, we measure structural and task similarity using the asymmetric KL divergence as the similarity metric.

![Figure 4: Structural and task similarity measurement.|scale=0.5](./fig/study2.png)
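
As a concrete sketch, the asymmetric KL similarity between two normalized fingerprint (or task-embedding) distributions could be computed as follows (the smoothing constant is our choice):

```python
import numpy as np

def kl_similarity(p, q, eps=1e-8):
    """Asymmetric KL divergence KL(p || q); lower values mean the
    distributions (and hence the datasets) are more similar."""
    p = p / (p.sum() + eps)
    q = q / (q.sum() + eps)
    return float(np.sum(p * np.log((p + eps) / (q + eps))))
```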

Based on this, we plot the regression curve and calculate the Pearson correlation between relative improvement and structural/task similarity over all the combination pairs in Figure 5(a). Additionally, we compute the structural and task similarity between each target dataset and the other 14 datasets individually. We then calculate the Pearson correlation between the similarity scores and the corresponding relative improvement for each target dataset, as presented in Figure 5(b) and (c).

![Figure 5: (a) Regression curves between relative improvement and the measures of structure similarity, task similarity, and their mixing. (b,c) Pearson correlation between relative improvement and similarity of fingerprint distribution/task embedding of each molecule dataset individually.|scale=0.8](./fig/study3.png)

This yields three findings:
* Combining task and structure leads to a stronger correlation.
* Both similar and dissimilar structures and tasks can benefit the target dataset.
* Structure and task are complementary.

## Method

![Figure 6: Overview of MolGroup.|scale=0.8](./fig/arch.png)

### Routing Mechanism

We apply a routing mechanism to quantify the affinity between datasets. The intuition is that the parameters of a positively transferred pair should be more "shared", while the parameters of a negatively transferred pair should be more "specific". Based on this, a routing function $g(\cdot)$ is introduced in each layer as follows:
$$
\mathbf{z}_m^{(l+1)}=\alpha_m f_{\theta_T}^{(l)}(\mathbf{z}_m^{(l)})+(1-\alpha_m)f_{\theta_m}^{(l)}(\mathbf{z}_m^{(l)}),\\
\text{with } \alpha_m=g_m(\mathcal{B}_T,\mathcal{B}_m)
$$
where $T$ denotes the target dataset, $m$ denotes the $m$-th auxiliary dataset, $\mathbf{z}_m^{(l)}$ is the output of the $l$-th layer for the $m$-th dataset, and $\mathcal{B}_T$ and $\mathcal{B}_m$ are the batches of the target and auxiliary dataset, respectively. $f_{\theta_T}^{(l)}$ and $f_{\theta_m}^{(l)}$ are neural networks with parameters $\theta_T$ and $\theta_m$.

Inspired by the empirical study, we calculate the gating score by combining structural and task similarity:
* For task affinity, we assign learnable embeddings $e^{\text{task}}_T,e^{\text{task}}_m$ to each dataset
* For structure affinity, we embed the fingerprint features of each molecule and apply Set2Set to obtain $e^{\text{struc}}_T,e^{\text{struc}}_m$

Finally, the gating score is computed as:
$$
\alpha_m=g_m(\mathcal{B}_T,\mathcal{B}_m)=\sigma(\lambda e^{\text{task}}_T\cdot e^{\text{task}}_m+(1-\lambda)e^{\text{struc}}_T\cdot e^{\text{struc}}_m)
$$
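
A PyTorch sketch of this gate (the Set2Set fingerprint encoder is abstracted into a per-dataset embedding here, so this is a simplification):

```python
import torch
import torch.nn as nn

class RoutingGate(nn.Module):
    """alpha_m = sigmoid(lambda * task_sim + (1 - lambda) * struc_sim)."""
    def __init__(self, num_datasets, dim, lam=0.5):
        super().__init__()
        self.task_emb = nn.Embedding(num_datasets, dim)   # e^task per dataset
        self.struc_emb = nn.Embedding(num_datasets, dim)  # stand-in for Set2Set output
        self.lam = lam

    def forward(self, target_id, aux_id):
        task = (self.task_emb(target_id) * self.task_emb(aux_id)).sum(-1)
        struc = (self.struc_emb(target_id) * self.struc_emb(aux_id)).sum(-1)
        return torch.sigmoid(self.lam * task + (1 - self.lam) * struc)
```

The resulting $\alpha_m$ then mixes the shared layer $f_{\theta_T}^{(l)}$ and the dataset-specific layer $f_{\theta_m}^{(l)}$ as in the routing equation above.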

### Bi-level Optimization Framework

What is the desired objective for auxiliary dataset grouping? It should optimize the routing mechanism toward minimizing the target dataset's loss, and an auxiliary dataset with greater benefit should be assigned a larger gating score. At each learning step, we explicitly write the target parameters optimized with the $m$-th auxiliary dataset as $\theta_T(\alpha_m)$. To achieve this goal, we break the optimization into two steps (see the sketch after this list):
* Lower level: optimize the target parameters $\theta_T$ with $\mathcal{B}_m$ to obtain $\theta_T(\alpha_m)$
    * $\theta_T(\alpha_m)=\arg\min_{\theta_T}L_m(\mathcal{B}_m;\theta_T,\theta_m,\alpha_m)$
* Higher level: optimize $g(\cdot)$ based on the loss of $\theta_T(\alpha_m)$ on $\mathcal{B}_T$
    * $\min_{\alpha_m}L_T(\mathcal{B}_T;\theta_T(\alpha_m))$
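
A minimal sketch of one such bi-level step (the loss callables and gate parameters are hypothetical stand-ins):

```python
import torch

def bilevel_step(theta_T, gate_params, inner_loss, outer_loss, lr=1e-3):
    """Inner step on the auxiliary batch, then an outer step that
    differentiates through it to update the gate."""
    # Lower level: compute theta_T(alpha_m), keeping the graph for the outer step.
    loss_m = inner_loss(theta_T)  # depends on alpha_m through the gate
    grads = torch.autograd.grad(loss_m, theta_T, create_graph=True)
    theta_T_new = [w - lr * g for w, g in zip(theta_T, grads)]

    # Upper level: target loss at the updated parameters; its gradient
    # w.r.t. the gate flows back through the inner update.
    loss_T = outer_loss(theta_T_new)
    gate_grads = torch.autograd.grad(loss_T, gate_params)
    return loss_T, gate_grads
```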

### Overall pipeline

* Step 1: initialize and train the model with the routing function on all datasets
    1) Update the target parameters through the routing mechanism with each auxiliary dataset
    2) Optimize the routing mechanism with the updated target parameters
* Step 2: filter the datasets, keeping those with gate scores above $\beta$
* Step 3: go to Step 4 if the iteration count reaches $n$; otherwise go back to Step 1
* Step 4: pick the auxiliary datasets with top-$k$ affinity

## Results

We evaluate the model on 15 molecule datasets, treating the 11 small datasets as targets. We report the performance of the model with the selected auxiliary datasets, using GIN and pretrained Graphormer as backbones. The results are shown in Figure 7. MolGroup outperforms the baseline methods, demonstrating its effectiveness in selecting auxiliary datasets.

![Figure 7: Benchmarking results with GIN and pretrained Graphormer backbones.|scale=0.8](./fig/results.png)

Figure 8 visualizes the selected auxiliary datasets for each target dataset, where each edge from an auxiliary dataset to a target dataset represents a selection. We make the following observations:
* PCBA is the most "famous" one, generally benefiting most of the target datasets
* Tox21 benefits ClinTox and ToxCast, which are both related to toxicity prediction
* Some datasets belong to distinct domains but can still benefit other datasets
    * QM8 for BBBP
    * ESOL for Lipo

![Figure 8: Selected grouping for each target dataset.|scale=0.5](./fig/graph.png)