Model Training (train.py)
Overview
The train.py script handles the complete training pipeline for cpiVAE models using PyTorch Lightning. It supports multiple model architectures, comprehensive logging, checkpointing, and automatic saving of preprocessing artifacts for downstream inference.
Usage
python scripts/train.py --config CONFIG_FILE --platform_a PLATFORM_A_FILE --platform_b PLATFORM_B_FILE [OPTIONS]
Required Arguments
- --config: Path to the configuration YAML file (e.g., configs/default.yaml)
- --platform_a: Path to the platform A CSV file (training data)
- --platform_b: Path to the platform B CSV file (training data)
Optional Arguments
- --output_dir: Output directory for logs and checkpoints (default: outputs)
- --experiment_name: Name of the experiment (default: joint_vae_experiment)
- --version: Experiment version (default: a timestamp of the form version_YYYYMMDD-HHMMSS)
- --resume_from_checkpoint: Path to a checkpoint to resume training from
- --fast_dev_run: Run a fast development run for debugging
Supported Model Architectures
Models
- joint_vae: The base JointVAE with dual encoders/decoders (recommended)
- joint_vae_plus: Enhanced version with additional tricks
- JointVAEVampPrior: VampPrior variant for improved posterior approximation
- JointIWAE: Importance Weighted Autoencoder variant
- JointVQ: Vector Quantized VAE variant
- JointMM: Mixture Model VAE variant
- res_unet: ResNet-UNet based direct imputation model
- generative_vae: Generative VAE with a DDPM or GAN decoder
Note that the base JointVAE performs the best in our testing.
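The architecture is selected through the model.model_type field of the configuration file rather than a command-line flag. As a hypothetical illustration (assuming PyYAML and the config layout shown under Configuration System below), a derived config for another architecture can be generated like this:

import yaml

# Load the default configuration shipped with the repository
with open("configs/default.yaml") as f:
    config = yaml.safe_load(f)

# Switch to another supported architecture, e.g. the VampPrior variant
config["model"]["model_type"] = "JointVAEVampPrior"

# Write a derived config and pass it to scripts/train.py via --config
with open("configs/vampprior.yaml", "w") as f:
    yaml.safe_dump(config, f)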
Training Pipeline
1. Data Loading and Preprocessing
- Loads paired proteomics data from CSV files
- Automatic feature-wise normalization (z-score, min-max, or robust)
- Optional log transformation per platform
- Missing value handling and sample alignment
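As an illustration of what this step does (not the actual implementation inside train.py), the sketch below applies a log transform and feature-wise z-score scaling with scikit-learn; the file path and the mean imputation are placeholders:

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler  # min-max / robust scalers are analogous

# First column = sample IDs, remaining columns = features
df = pd.read_csv("data/olink_overlap_train.csv", index_col=0)

# Placeholder for the pipeline's missing-value handling
df = df.fillna(df.mean())

# Optional per-platform log transform, then feature-wise z-score scaling;
# the fitted scaler is what gets persisted (e.g. as scaler_a.pkl) for inference
X = np.log1p(df.to_numpy())
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)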
2. Model Initialization
- Automatic input dimension detection
- Model architecture selected via configuration
- Parameter count reporting
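The parameter count report follows the usual PyTorch idiom; a minimal sketch (the Sequential model here is only a stand-in for the configured architecture):

import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Number of trainable parameters, as reported after model initialization."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 100))  # placeholder model
print(f"Trainable parameters: {count_parameters(model):,}")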
3. Training Loop
- PyTorch Lightning trainer with GPU/CPU acceleration
- Gradient clipping and learning rate monitoring
- Early stopping based on validation metrics
- Model checkpointing (best and last models)
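These behaviors correspond to standard PyTorch Lightning components. The sketch below wires them together with illustrative values and the val_total_loss metric used elsewhere on this page; it is not a copy of the script's actual trainer setup:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

callbacks = [
    ModelCheckpoint(monitor="val_total_loss", mode="min", save_top_k=1, save_last=True),
    EarlyStopping(monitor="val_total_loss", mode="min", patience=10),
    LearningRateMonitor(logging_interval="epoch"),
]

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",      # GPU/CPU selected automatically
    devices="auto",
    gradient_clip_val=0.5,   # gradient clipping
    callbacks=callbacks,
)
# trainer.fit(model, datamodule)  # model and datamodule come from the earlier steps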
4. Artifact Saving
- Preprocessing scalers (scaler_a.pkl, scaler_b.pkl)
- Feature names (feature_names_a.txt, feature_names_b.txt)
- Log transformation parameters (log_params.yaml)
- Configuration backup (config.yaml)
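These artifacts are what inference scripts load so that new data is preprocessed identically to the training data. A minimal sketch of how such files might be written, assuming joblib for the scalers (the exact serialization used by train.py may differ):

import os
import joblib
import yaml
from sklearn.preprocessing import StandardScaler

out_dir = "outputs/example_run"           # hypothetical run directory
os.makedirs(out_dir, exist_ok=True)

# Placeholders standing in for the pipeline's fitted scalers and feature lists
scaler_a, scaler_b = StandardScaler(), StandardScaler()
feature_names_a = ["protein_1", "protein_2"]

joblib.dump(scaler_a, os.path.join(out_dir, "scaler_a.pkl"))
joblib.dump(scaler_b, os.path.join(out_dir, "scaler_b.pkl"))

with open(os.path.join(out_dir, "feature_names_a.txt"), "w") as f:
    f.write("\n".join(feature_names_a))

with open(os.path.join(out_dir, "log_params.yaml"), "w") as f:
    yaml.safe_dump({"log_transform_a": True, "log_transform_b": True}, f)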
Output Structure
Training creates the following directory structure:
{output_dir}/
└── {experiment_name}/
    └── {version}/
        ├── checkpoints/
        │   ├── {experiment_name}-epoch=XX-val_total_loss=X.XXX.ckpt
        │   └── last.ckpt
        ├── config.yaml
        ├── scaler_a.pkl
        ├── scaler_b.pkl
        ├── feature_names_a.txt
        ├── feature_names_b.txt
        ├── log_params.yaml
        ├── final_model.ckpt
        └── tensorboard_logs/
Configuration System
Model Configuration
model:
  model_type: "joint_vae"          # Architecture selection
  latent_dim: 64                   # Latent space dimensionality
  encoder_layers: [512, 256, 128]  # Encoder architecture
  decoder_layers: [128, 256, 512]  # Decoder architecture
  dropout_rate: 0.2                # Dropout for regularization
  activation: "relu"               # Activation function
  batch_norm: true                 # Batch normalization
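The encoder_layers and decoder_layers lists specify hidden-layer widths. As a hypothetical illustration of how such a specification can be turned into an MLP block (not the actual cpiVAE code), combining the dropout, activation, and batch-norm options above:

import torch.nn as nn

def build_mlp(in_dim, hidden, out_dim, dropout=0.2, batch_norm=True):
    """Stack Linear -> (BatchNorm) -> ReLU -> Dropout blocks from a width list."""
    layers, prev = [], in_dim
    for width in hidden:                      # e.g. [512, 256, 128] for the encoder
        layers.append(nn.Linear(prev, width))
        if batch_norm:
            layers.append(nn.BatchNorm1d(width))
        layers += [nn.ReLU(), nn.Dropout(dropout)]
        prev = width
    layers.append(nn.Linear(prev, out_dim))   # e.g. 2 * latent_dim for mu and logvar
    return nn.Sequential(*layers)

encoder = build_mlp(in_dim=1000, hidden=[512, 256, 128], out_dim=2 * 64)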
Training Configuration
training:
  max_epochs: 100
  learning_rate: 0.001
  batch_size: 64
  optimizer: "adam"
  gradient_clip_val: 0.5
  early_stopping_patience: 10
Hardware Configuration
hardware:
  accelerator: "auto"  # "auto", "gpu", "cpu"
  devices: "auto"      # Device selection
  precision: 32        # 32-bit full precision; set to 16 for mixed precision
Examples
Basic Training
python scripts/train.py \
--config configs/default.yaml \
--platform_a data/olink_overlap_train.csv \
--platform_b data/somascan_overlap_train.csv
Custom Experiment
python scripts/train.py \
--config configs/default.yaml \
--platform_a data/olink_overlap_train.csv \
--platform_b data/somascan_overlap_train.csv \
--output_dir outputs_custom \
--experiment_name my_joint_vae \
--version v1.0
Resume Training
python scripts/train.py \
--config configs/default.yaml \
--platform_a data/olink_overlap_train.csv \
--platform_b data/somascan_overlap_train.csv \
--resume_from_checkpoint outputs/joint_vae_experiment/version_20250804-120000/checkpoints/last.ckpt
Debug Mode
python scripts/train.py \
--config configs/default.yaml \
--platform_a data/olink_overlap_train.csv \
--platform_b data/somascan_overlap_train.csv \
--fast_dev_run
Monitoring and Callbacks
TensorBoard Logging
View training progress:
tensorboard --logdir outputs/joint_vae_experiment/version_YYYYMMDD-HHMMSS/
Logged metrics include:
- Training/validation losses (total, reconstruction, KL, cross-reconstruction)
- Cross-imputation correlations
- Learning rate schedules
- Model gradients and weights
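These metrics reach TensorBoard through Lightning's self.log calls inside the model's training and validation steps. A hypothetical fragment (loss values are placeholders, and metric names other than val_total_loss are illustrative):

import pytorch_lightning as pl
import torch

class LoggingSketch(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # Placeholder values; the real ones come from the model's forward pass
        recon, kl, cross = torch.rand(()), torch.rand(()), torch.rand(())
        total = recon + kl + cross
        self.log("val_total_loss", total)   # monitored by checkpointing/early stopping
        self.log("val_recon_loss", recon)
        self.log("val_kl_loss", kl)
        self.log("val_cross_recon_loss", cross)
        return total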
Model Checkpointing
- Best model: Saved based on the monitored metric (default: val_total_loss)
- Last model: Most recent checkpoint
- Final model: Saved at training completion
- Configurable via callbacks.model_checkpoint settings
Early Stopping
- Monitors validation metrics for convergence
- Configurable patience and minimum delta
- Prevents overfitting on small datasets
Advanced Features
Multi-GPU Training
hardware:
  accelerator: "gpu"
  devices: [0, 1, 2, 3]  # Use specific GPUs
Mixed Precision Training
hardware:
  precision: 16  # Use 16-bit precision for faster training
Custom Callbacks
callbacks:
  model_checkpoint:
    monitor: "val_cross_a_corr_mean"
    mode: "max"
    save_top_k: 5
  early_stopping:
    monitor: "val_cross_a_corr_mean"
    mode: "max"
    patience: 15
Loss Function Components
The composite loss function includes:
- Reconstruction Loss: Platform-specific data reconstruction
- KL Divergence: Regularization of latent space
- Cross-reconstruction Loss: Cross-platform imputation quality
- Latent Alignment Loss: Shared representation learning
Loss weights are configurable via loss_weights in the config file.
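Conceptually, the total objective is a weighted sum of these terms. A minimal sketch under assumed weight names (the actual keys under loss_weights may differ):

import torch

def composite_loss(recon_a, recon_b, kl, cross_a, cross_b, align, w):
    """Weighted sum of the loss components listed above (weight keys are illustrative)."""
    return (
        w["reconstruction"] * (recon_a + recon_b)
        + w["kl"] * kl
        + w["cross_reconstruction"] * (cross_a + cross_b)
        + w["alignment"] * align
    )

# Placeholder scalar losses for demonstration
weights = {"reconstruction": 1.0, "kl": 0.1, "cross_reconstruction": 1.0, "alignment": 0.5}
total = composite_loss(*torch.rand(6), w=weights)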
Data Format Requirements
Input CSV Format
- First column: Sample IDs (must match between platforms)
- Remaining columns: Feature measurements
- Missing values: Handled automatically during preprocessing
Sample Alignment
- Samples are automatically aligned by ID
- Mismatched samples are excluded with warnings
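A minimal sketch of this contract, reusing the example file names from above: load each platform with the first column as the sample-ID index and keep only the IDs shared by both files:

import pandas as pd

# First column = sample IDs, remaining columns = feature measurements
df_a = pd.read_csv("data/olink_overlap_train.csv", index_col=0)
df_b = pd.read_csv("data/somascan_overlap_train.csv", index_col=0)

# Align samples by ID; anything without a cross-platform match is dropped
shared = df_a.index.intersection(df_b.index)
dropped = len(df_a.index.union(df_b.index)) - len(shared)
if dropped:
    print(f"Warning: excluding {dropped} samples without a match on both platforms")
df_a, df_b = df_a.loc[shared], df_b.loc[shared]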
Performance Optimization
Memory Management
- Custom, efficient data loading built on the PyTorch DataLoader
- Gradient accumulation to simulate larger effective batches (sketched below)
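In Lightning, gradient accumulation is a single Trainer argument; a minimal sketch with illustrative numbers:

import pytorch_lightning as pl

# A per-step batch size of 16 with 4 accumulation steps behaves like an
# effective batch size of 64 while using a quarter of the memory per step
trainer = pl.Trainer(accumulate_grad_batches=4, max_epochs=100)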
Troubleshooting
Common Issues
Out of Memory
- Reduce the batch size in the configuration
- Use mixed precision training (precision: 16)
- Reduce the model architecture size

Slow Convergence
- Increase the learning rate
- Adjust the loss function weights
- Check data normalization

Poor Cross-imputation Performance
- Increase the cross-reconstruction loss weight
- Adjust the latent alignment parameters
- Verify data quality and preprocessing
Model Selection
- Monitor validation metrics, not training loss
- Use cross-validation for small datasets
- Compare multiple model architectures
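For small datasets, a simple K-fold split over sample IDs is one way to compare architectures on held-out data. The sketch below only prepares the folds; launching scripts/train.py per fold and collecting metrics is left to the surrounding workflow:

import numpy as np
from sklearn.model_selection import KFold

sample_ids = np.array([f"sample_{i}" for i in range(100)])   # placeholder IDs
kfold = KFold(n_splits=5, shuffle=True, random_state=0)

for fold, (train_idx, val_idx) in enumerate(kfold.split(sample_ids)):
    train_ids, val_ids = sample_ids[train_idx], sample_ids[val_idx]
    # Write per-fold CSV subsets for each platform, train on them, and
    # compare validation metrics (e.g. val_total_loss) across folds.
    print(f"fold {fold}: {len(train_ids)} train / {len(val_ids)} val samples")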