
Graylab/GermRL

GermRL

GermRL is a reinforcement learning framework for alleviating germline bias in autoregressive antibody language models. It implements a custom GRPO (Group Relative Policy Optimization) algorithm in which every prompt begins from the start token. The implementation supports training and inference with the Hugging Face ProGen2-OAS implementation by Hrbáň et al., adapted from Nijkamp et al.
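The defining step of GRPO is the group-relative advantage. For a prompt (here always the start token "1"), a group of trajectories is sampled, and each trajectory's reward is normalized against the mean and standard deviation of its group, so no learned value function is needed. The sketch below shows the standard formulation only. The helper name, the use of population rather than sample standard deviation, and the `eps` value are assumptions, not taken from GermRL's code.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Normalize rewards within one group of trajectories sampled from the same prompt.

    Standard GRPO advantage: A_i = (r_i - mean(r)) / (std(r) + eps).
    Hypothetical helper for illustration; not GermRL's actual function.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group (assumed choice)
    return [(r - mean) / (std + eps) for r in rewards]


# Example: four trajectories from the start-token prompt with binary rewards.
adv = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print([round(a, 3) for a in adv])  # [1.0, -1.0, -1.0, 1.0]
```

If every trajectory in a group receives the same reward, all advantages are zero and that group contributes no policy-gradient signal.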

Installation

  1. Clone the repository:
git clone https://github.com/yourusername/GermRL.git
cd GermRL
  2. Create the conda environment using the provided environment file (mamba recommended):
mamba env create -f environment.yaml
conda activate germrl

Training

Use one of the example YAML files in train/ as a starting point. Recommended examples:

  • train/LD25_training.yaml for single LD threshold
  • train/LD15_25training.yaml for a bounded LD threshold (lower and upper bound)

Run the training script with a configuration file:
python train/GermRL_train.py --config train/LD5_training.yaml

Training Parameters

Key parameters in the YAML config for training:

  • output_dir: str. Directory for training visualizations and logs
  • weight_output_dir: str. Directory to save model checkpoints
  • anarci_outfile: str. File path to store intermediate ANARCI results
  • log_file: str. File path to logging file
  • lr: float. Learning rate
  • t: float. Temperature for ProGen2-OAS sampling
  • p: float. ProGen2-OAS Top-p (nucleus) sampling parameter
  • GRPO_epochs: int. Number of training epochs
  • target_thresh_LD: int. Target Levenshtein Distance threshold
  • target_thresh_fold: float. Target pLDDT threshold
  • num_traj: int. Number of trajectories to generate per training step
  • steps: int. Number of steps per epoch
  • num_updates: int. Number of update iterations per step
  • ref_coef: float. Weight assigned to the KL penalty from reference policy
  • entropy_coef: float. Weight assigned to the entropy penalty
  • clip_higher: float. Amount added to the upper clipping bound (clip-higher)
  • dataset_prompts: list[str]. Prompts sampled during GRPO. For ProGen2-OAS this is kept as a single prompt, ["1"], the start token
  • num_prompts_sample: int. Number of prompts to sample per step (default: 1)
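Three of these parameters (clip_higher, ref_coef, entropy_coef) typically enter the per-token policy loss. The sketch below shows one common way they combine: a PPO-style clipped surrogate whose upper bound is widened by clip_higher, a k3 estimator of the KL divergence to the reference policy weighted by ref_coef, and an entropy term weighted by entropy_coef. The base clipping range `clip_eps`, all default values, the choice of KL estimator, and the sign of the entropy term (treated here as an exploration bonus) are assumptions, not read from GermRL's code; `grpo_token_loss` is a hypothetical name.

```python
import math


def grpo_token_loss(
    logp_new: float,
    logp_old: float,
    logp_ref: float,
    advantage: float,
    entropy: float,
    clip_eps: float = 0.2,      # hypothetical base clipping range
    clip_higher: float = 0.08,  # extra room added to the upper bound only
    ref_coef: float = 0.04,     # weight on the KL term
    entropy_coef: float = 0.01, # weight on the entropy term
) -> float:
    """Loss (to minimize) for a single generated token. Illustrative sketch only."""
    # Importance ratio between the updating policy and the policy that sampled the token.
    ratio = math.exp(logp_new - logp_old)
    # Asymmetric "clip-higher" clipping: ratio is kept in [1 - eps, 1 + eps + clip_higher].
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps + clip_higher)
    surrogate = min(ratio * advantage, clipped * advantage)
    # k3 estimator of KL(pi_theta || pi_ref); always >= 0.
    log_ratio_ref = logp_ref - logp_new
    kl = math.exp(log_ratio_ref) - log_ratio_ref - 1.0
    # Maximize the surrogate, penalize divergence from the reference, reward entropy (assumed sign).
    return -(surrogate - ref_coef * kl + entropy_coef * entropy)
```

With a positive advantage, raising clip_higher lets the probability of a rewarded token grow further in one update before the clip takes effect, which is the usual motivation for clipping asymmetrically. If GermRL treats entropy as a penalty, the sign of that term flips.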

Parameters controlling the GRPO modifications:

  • training_scheme_default: bool. When False (default), two modifications to standard GRPO are applied: (1) weights are synchronized once per epoch rather than every step, and (2) trajectories are sampled exclusively from the updating policy. Set to True to train with standard GRPO, without these modifications.
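The first modification changes only when synchronization happens. The toy schedule below, built around a hypothetical `sync_points` helper, illustrates that timing in terms of GRPO_epochs and steps. Which copy of the weights is synchronized is not specified in this README, and the second modification (sampling only from the updating policy) is not modeled.

```python
def sync_points(grpo_epochs: int, steps: int, training_scheme_default: bool) -> list[tuple[int, int]]:
    """Return the (epoch, step) pairs after which weights are synchronized.

    Hypothetical illustration of the documented behavior:
    - training_scheme_default=True: synchronize after every step (standard GRPO).
    - training_scheme_default=False: synchronize once, after the last step of each epoch.
    """
    points = []
    for epoch in range(grpo_epochs):
        for step in range(steps):
            last_step = step == steps - 1
            if training_scheme_default or last_step:
                points.append((epoch, step))
    return points


print(sync_points(2, 3, training_scheme_default=False))  # [(0, 2), (1, 2)]
print(len(sync_points(2, 3, training_scheme_default=True)))  # 6
```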

Parameters for the bounded LD objective:

  • bound_thresh_ld: int. Upper LD threshold bound
  • RGT: bool. Set to True to apply the RGT reward
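The LD objectives can be illustrated with a plain Levenshtein distance and a threshold check. The sketch below assumes a binary reward of 1.0 when a generated sequence's LD to its germline is at least target_thresh_LD and, in the bounded case, at most bound_thresh_ld (matching the LD15_25 example). The binary form, the direction of the thresholds, and the `germline` input (presumably derived from the ANARCI germline assignment) are assumptions; the pLDDT threshold and the RGT reward are not sketched, since this README does not define how they are combined or computed. `levenshtein` and `ld_reward` are hypothetical names.

```python
from typing import Optional


def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) using a rolling DP row."""
    if len(a) < len(b):
        a, b = b, a  # keep the DP row as short as possible
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution (free when residues match)
            ))
        prev = curr
    return prev[-1]


def ld_reward(seq: str, germline: str, target_thresh_ld: int,
              bound_thresh_ld: Optional[int] = None) -> float:
    """Hypothetical binary reward: 1.0 if LD to germline is within [target, bound]."""
    ld = levenshtein(seq, germline)
    if ld < target_thresh_ld:
        return 0.0
    if bound_thresh_ld is not None and ld > bound_thresh_ld:
        return 0.0
    return 1.0
```

For example, `ld_reward("ABCD", "AAAA", 3)` returns 1.0 (three substitutions), while adding `bound_thresh_ld=2` makes the same sequence fall outside the bound and return 0.0.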

Generation

Generate antibody sequences using a trained model. ProGen2-RL model weights are available on Hugging Face (for example, "lludwig2/GermRL-LD5"). For the base model, use "hugohrban/progen2-oas".

python gen/generate.py \
    --base_path outputs/example_output_gen_viz \
    --model_weights "lludwig2/GermRL-LD5-LD15" \
    --logger_file outputs/gen_logs/generation.log \
    --temps 1.0 \
    --num_gen 20 \
    --max_len 250 \
    --top_p_val 0.9 \
    --anarci_path anarci_out.txt

Generation Parameters

  • --base_path: str. Output directory for generated sequences and analysis
  • --model_weights: str. Local path or Hugging Face model ID of the model weights
  • --logger_file: str. Path to log file
  • --temps: float. Sampling temperature(s); multiple values allowed
  • --num_gen: int. Number of sequences to generate
  • --max_len: int. Maximum sequence length
  • --top_p_val: float. Top-p (nucleus) sampling parameter
  • --anarci_path: str. Path for the intermediate file containing ANARCI results
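--temps and --top_p_val control how each residue is drawn from the model's next-token distribution. The pure-Python sketch below applies temperature scaling followed by nucleus (top-p) filtering over a toy vocabulary. In practice this is done on logit tensors, for example through the Hugging Face `generate` method with `do_sample=True`, `temperature`, and `top_p`; whether gen/generate.py takes that path is an assumption, and `top_p_sample` is a hypothetical name.

```python
import math
import random


def top_p_sample(logits: dict[str, float], temperature: float, top_p: float,
                 rng: random.Random) -> str:
    """Sample one token with temperature scaling, then nucleus (top-p) filtering."""
    # Temperature-scaled softmax; subtract the max logit for numerical stability.
    scaled = {tok: value / temperature for tok, value in logits.items()}
    m = max(scaled.values())
    exp = {tok: math.exp(v - m) for tok, v in scaled.items()}
    z = sum(exp.values())
    probs = sorted(((tok, e / z) for tok, e in exp.items()),
                   key=lambda kv: kv[1], reverse=True)
    # Keep the smallest set of most-probable tokens whose cumulative mass reaches top_p.
    nucleus, cumulative = [], 0.0
    for tok, p in probs:
        nucleus.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # random.choices renormalizes the weights within the nucleus.
    tokens, weights = zip(*nucleus)
    return rng.choices(tokens, weights=weights, k=1)[0]
```

Lower temperatures sharpen the distribution before filtering, so fewer residues survive a given top-p cutoff. A top-p of 1.0 disables the filter.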

Output Files

Generation creates several output files in the specified base directory:

  • seq.csv: Generated antibody sequences
  • LV_Fold.csv: Sequences with LD and folding scores
  • v_calls.json: V gene assignments
  • j_calls.json: J gene assignments
  • seq_num_redo.json: Regeneration statistics for improper generations
