EpiBench: Model Architecture Improvements
This note documents recent improvements to the EpiBench model architecture. Last updated: 2024-11-15
Model Architecture Overview
Based on initial results with the Washington University datasets, I’ve been working on several improvements to the EpiBench model architecture. These changes aim to enhance prediction accuracy, improve interpretability, and optimize computational efficiency.
Limitations of Previous Architecture
The original EpiBench model had several limitations when applied to our new datasets:
- Limited context integration: The model struggled to capture long-range dependencies in the genomic sequence
- Insufficient multi-modal fusion: Histone modification data wasn’t optimally integrated with sequence features
- Computational inefficiency: Training on full-genome datasets was prohibitively expensive
- Interpretability challenges: Attribution of predictions to specific features was limited
New Architecture Components
1. Transformer-Based Sequence Processing
I’ve replaced the original CNN-only approach with a hybrid architecture that includes transformer components:
```python
import torch
import torch.nn as nn

class SequenceTransformer(nn.Module):
    """Transformer encoder over integer-encoded DNA sequence."""
    def __init__(self, seq_length=1000, n_heads=8, dim_feedforward=2048, n_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(5, 256)  # 5 tokens: A, C, G, T, N
        self.pos_encoding = PositionalEncoding(256, seq_length)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,  # inputs are (batch, seq_len, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)

    def forward(self, x):           # x: (batch, seq_len) integer tokens
        x = self.embedding(x)       # (batch, seq_len, 256)
        x = self.pos_encoding(x)
        return self.transformer(x)
```
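The snippet references a `PositionalEncoding` module that isn't shown here. A minimal sinusoidal version consistent with the call above might look like this (an assumption on my part; the actual module may differ):

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to token embeddings."""
    def __init__(self, d_model, max_len):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):                 # x: (batch, seq_len, d_model)
        return x + self.pe[:, : x.size(1)]

# Combined with the class above: a batch of 16 integer-encoded 1,000 bp windows.
tokens = torch.randint(0, 5, (16, 1000))
features = SequenceTransformer()(tokens)  # (16, 1000, 256)
```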
2. Improved Multi-Modal Fusion
The new architecture implements a more sophisticated fusion mechanism for integrating sequence and epigenetic data:
```python
import torch
import torch.nn as nn

class MultiModalFusion(nn.Module):
    """Fuses sequence features with histone-modification features via cross-attention."""
    def __init__(self, seq_dim=256, hist_dim=128, fusion_dim=512):
        super().__init__()
        self.seq_projection = nn.Linear(seq_dim, fusion_dim)
        self.hist_projection = nn.Linear(hist_dim, fusion_dim)
        self.attention = nn.MultiheadAttention(fusion_dim, num_heads=8, batch_first=True)
        self.fusion_norm = nn.LayerNorm(fusion_dim)

    def forward(self, seq_features, hist_features):
        seq_proj = self.seq_projection(seq_features)     # (batch, seq_len, fusion_dim)
        hist_proj = self.hist_projection(hist_features)  # (batch, hist_len, fusion_dim)
        # Cross-modal attention: sequence positions query the histone features.
        fusion, _ = self.attention(seq_proj, hist_proj, hist_proj)
        # Residual connection plus layer norm, as in standard transformer blocks.
        return self.fusion_norm(fusion + seq_proj)
```
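As a quick shape check (the tensor sizes here are illustrative, not from the note):

```python
import torch

fusion = MultiModalFusion()
seq_feats = torch.randn(16, 1000, 256)   # output of SequenceTransformer
hist_feats = torch.randn(16, 200, 128)   # e.g. 200 bins of histone-mark signal
fused = fusion(seq_feats, hist_feats)    # (16, 1000, 512)
```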
3. Region-Specific Prediction Heads
To allow for different prediction strategies in different genomic contexts, I’ve implemented region-specific prediction heads (a sketch follows the list):
- Promoter-specific prediction head
- Enhancer-specific prediction head
- Gene body-specific prediction head
- Intergenic region-specific head
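One way to wire these up is a per-region `nn.ModuleDict`. This is a sketch under assumptions: the head names, hidden sizes, output dimension, and the string-valued region annotation are mine, not the shipped implementation:

```python
import torch
import torch.nn as nn

class RegionSpecificHeads(nn.Module):
    """Routes fused features to a prediction head chosen by genomic region."""
    REGIONS = ("promoter", "enhancer", "gene_body", "intergenic")

    def __init__(self, fusion_dim=512):
        super().__init__()
        self.heads = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(fusion_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 1),   # e.g. a per-position methylation estimate
            )
            for name in self.REGIONS
        })

    def forward(self, features, region):
        # `region` is a per-window annotation (here, a string key).
        return self.heads[region](features)

heads = RegionSpecificHeads()
pred = heads(torch.randn(16, 1000, 512), "promoter")  # (16, 1000, 1)
```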
Performance Improvements
Initial benchmarks show significant improvements:
| Metric | Original Model | New Architecture | Improvement |
|---|---|---|---|
| Methylation Prediction MSE | 0.085 | 0.062 | 27% |
| AUROC (Classification) | 0.83 | 0.91 | 9.6% |
| Training Time (per epoch) | 45 min | 28 min | 38% |
| GPU Memory Usage | 22 GB | 14 GB | 36% |
Interpretability Enhancements
The new architecture incorporates several features to improve interpretability:
- Attention Visualization: Transformer attention maps reveal sequence-epigenetic interactions
- Feature Attribution: Integrated gradients method for attributing predictions to input features (see the sketch after this list)
- Region-Specific Explanations: Separate explanations for different genomic contexts
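For the integrated-gradients piece, Captum is one option (an assumption; the note doesn't name a library). The sketch below uses a hypothetical stand-in model with continuous inputs; for integer-encoded sequence one would instead attribute through the embedding layer with Captum's `LayerIntegratedGradients`:

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Hypothetical stand-in for a trained EpiBench-style model.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))
inputs = torch.randn(4, 128, requires_grad=True)

ig = IntegratedGradients(model)
baseline = torch.zeros_like(inputs)   # all-zero reference input
attributions, delta = ig.attribute(
    inputs, baselines=baseline, target=0, return_convergence_delta=True
)
# `attributions` matches `inputs` in shape; larger magnitudes mark features
# that moved the prediction further from the baseline output.
```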
Current Challenges
Several challenges remain to be addressed:
- Further optimization needed for whole-genome scale analysis
- Integration of additional epigenetic features (ATAC-seq, 3D genome)
- Handling of sample heterogeneity in cancer datasets
- Adaptation to limited-data scenarios
Next Development Steps
- Implement transfer learning capabilities to leverage pre-trained models
- Develop specialized architectures for cancer-specific applications
- Create ensemble models that combine predictions from multiple architectures
- Design interpretability tools tailored to biologists without ML expertise
Code Availability
The improved model architecture is available in the `dev` branch of our repository. Documentation and examples are being prepared for the next release.
This is a working document that will be updated as model development continues. Related: WashU Data Integration