EpiBench: Model Architecture Improvements

Neural Network Architecture Diagram

This note documents recent improvements to the EpiBench model architecture. Last updated: 2024-11-15

Model Architecture Overview

Based on initial results with the Washington University datasets, I’ve been working on several improvements to the EpiBench model architecture. These changes aim to enhance prediction accuracy, improve interpretability, and optimize computational efficiency.

Limitations of Previous Architecture

The original EpiBench model had several limitations when applied to our new datasets:

  1. Limited context integration: The model struggled to capture long-range dependencies in the genomic sequence
  2. Insufficient multi-modal fusion: Histone modification data wasn’t optimally integrated with sequence features
  3. Computational inefficiency: Training on full-genome datasets was prohibitively expensive
  4. Interpretability challenges: Attribution of predictions to specific features was limited

New Architecture Components

1. Transformer-Based Sequence Processing

I’ve replaced the original CNN-only approach with a hybrid architecture that includes transformer components:

class SequenceTransformer(nn.Module):
    def __init__(self, seq_length=1000, n_heads=8, dim_feedforward=2048, n_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(5, 256)  # ACGTN
        self.pos_encoding = PositionalEncoding(256, seq_length)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256, 
            nhead=n_heads,
            dim_feedforward=dim_feedforward
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        
    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        return self.transformer(x)

2. Improved Multi-Modal Fusion

The new architecture implements a more sophisticated fusion mechanism for integrating sequence and epigenetic data:

class MultiModalFusion(nn.Module):
    def __init__(self, seq_dim=256, hist_dim=128, fusion_dim=512):
        super().__init__()
        self.seq_projection = nn.Linear(seq_dim, fusion_dim)
        self.hist_projection = nn.Linear(hist_dim, fusion_dim)
        self.attention = MultiHeadAttention(fusion_dim, 8)
        self.fusion_norm = nn.LayerNorm(fusion_dim)
        
    def forward(self, seq_features, hist_features):
        seq_proj = self.seq_projection(seq_features)
        hist_proj = self.hist_projection(hist_features)
        
        # Cross-modal attention
        fusion = self.attention(seq_proj, hist_proj, hist_proj)
        fusion = self.fusion_norm(fusion + seq_proj)
        
        return fusion

3. Region-Specific Prediction Heads

To allow for different prediction strategies in different genomic contexts, I’ve implemented region-specific prediction heads:

  • Promoter-specific prediction head
  • Enhancer-specific prediction head
  • Gene body-specific prediction head
  • Intergenic region-specific head

Performance Improvements

Initial benchmarks show significant improvements:

MetricOriginal ModelNew ArchitectureImprovement
Methylation Prediction MSE0.0850.06227%
AUROC (Classification)0.830.919.6%
Training Time (per epoch)45 min28 min38%
GPU Memory Usage22GB14GB36%

Interpretability Enhancements

The new architecture incorporates several features to improve interpretability:

  1. Attention Visualization: Transformer attention maps reveal sequence-epigenetic interactions
  2. Feature Attribution: Integrated gradients method for attributing predictions to input features
  3. Region-Specific Explanations: Separate explanations for different genomic contexts

Current Challenges

Several challenges remain to be addressed:

  • Further optimization needed for whole-genome scale analysis
  • Integration of additional epigenetic features (ATAC-seq, 3D genome)
  • Handling of sample heterogeneity in cancer datasets
  • Adaptation to limited-data scenarios

Next Development Steps

  1. Implement transfer learning capabilities to leverage pre-trained models
  2. Develop specialized architectures for cancer-specific applications
  3. Create ensemble models that combine predictions from multiple architectures
  4. Design interpretability tools tailored to biologists without ML expertise

Code Availability

The improved model architecture is available in the dev branch of our repository. Documentation and examples are being prepared for the next release.


This is a working document that will be updated as model development continues. Related: WashU Data Integration