GermRL is a reinforcement learning framework for alleviating germline bias in autoregressive antibody language models. It implements a custom GRPO (Group Relative Policy Optimization) algorithm in which every prompt begins from the start token. The repository supports training and inference with the Hugging Face ProGen2-OAS implementation by Hrbáň et al., adapted from Nijkamp et al.
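GRPO replaces a learned value baseline with a group-relative one: each trajectory's reward is normalized against the other trajectories sampled from the same prompt. Because every GermRL prompt is the start token, the trajectories generated in one step can be treated as a single group. The sketch below shows only that standard normalization step; the function name and the epsilon value are illustrative assumptions, not taken from GermRL's code.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Standard GRPO advantage: (r - mean) / (std + eps) within one group.

    All trajectories are assumed to share the same prompt (the ProGen2-OAS
    start token "1"), so they form one group. Population std is used here;
    some implementations use the sample std instead.
    """
    if len(rewards) < 2:
        # A lone trajectory has nothing to be compared against.
        return [0.0 for _ in rewards]
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled sequences, two of which met the reward criteria:
# rewarded sequences get positive advantages, the others negative ones.
print(group_relative_advantages([1.0, 0.0, 1.0, 0.0]))
```

If every trajectory in a group receives the same reward, all advantages are zero and that group contributes no policy-gradient signal.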
- Clone the repository:

```bash
git clone https://github.com/yourusername/GermRL.git
cd GermRL
```

- Create the conda environment from the provided environment file (mamba recommended):
```bash
mamba env create -f environment.yaml
conda activate germrl
```

Use one of the example YAML files in train/ as a starting point. Recommended examples:
- train/LD25_training.yaml for a single LD threshold
- train/LD15_25training.yaml for a bounded LD threshold
Run the training script with a configuration file:

```bash
python train/GermRL_train.py --config train/LD5_training.yaml
```

Key parameters in the YAML config for training:
- output_dir: str. Directory for training visualizations and logs
- weight_output_dir: str. Directory to save model checkpoints
- anarci_outfile: str. File path to store intermediate ANARCI results
- log_file: str. File path to the logging file
- lr: float. Learning rate
- t: float. Temperature for ProGen2-OAS sampling
- p: float. Top-p (nucleus) sampling parameter for ProGen2-OAS
- GRPO_epochs: int. Number of training epochs
- target_thresh_LD: int. Target Levenshtein distance (LD) threshold
- target_thresh_fold: float. Target pLDDT threshold
- num_traj: int. Number of trajectories to generate per training step
- steps: int. Number of steps per epoch
- num_updates: int. Number of update iterations per step
- ref_coef: float. Weight of the KL penalty against the reference policy
- entropy_coef: float. Weight of the entropy penalty
- clip_higher: float. Increase applied to the upper clipping bound
- dataset_prompts: list[str]. Prompts sampled during GRPO; for ProGen2-OAS this is kept as a single prompt ["1"], the start token
- num_prompts_sample: int. Number of prompts to sample per step (default 1)
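The ref_coef, entropy_coef, and clip_higher parameters correspond to terms of a PPO-style clipped objective. The sketch below assumes the common GRPO formulation: a clipped importance ratio, the k3 estimator of KL divergence against the reference policy, and an entropy term. It also assumes a "clip-higher" variant in which the upper clip bound becomes 1 + eps + clip_higher. The base clip range eps = 0.2, the function name, and the sign of the entropy term are assumptions; GermRL's actual loss may differ.

```python
import math


def grpo_token_loss(logp_new: float, logp_old: float, logp_ref: float,
                    advantage: float, entropy: float, *,
                    ref_coef: float, entropy_coef: float,
                    clip_higher: float, eps: float = 0.2) -> float:
    """Per-token loss to minimize (hypothetical sketch, not GermRL's code).

    eps=0.2 is an assumed base clip range; clip_higher widens only the upper
    bound, so the ratio is clipped to [1 - eps, 1 + eps + clip_higher].
    """
    ratio = math.exp(logp_new - logp_old)                    # pi_theta / pi_old
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps + clip_higher)
    surrogate = min(ratio * advantage, clipped * advantage)  # pessimistic PPO bound
    # k3 estimator of KL(pi_theta || pi_ref): exp(d) - d - 1, d = log pi_ref - log pi_theta
    d = logp_ref - logp_new
    kl = math.exp(d) - d - 1.0
    # Assumed sign convention: higher entropy lowers the loss.
    return -surrogate + ref_coef * kl - entropy_coef * entropy


# A positive-advantage token whose ratio exp(0.3) ~ 1.35 exceeds the upper bound:
# the loss is -1.2 with clip_higher=0.0 and -1.28 with clip_higher=0.08.
print(grpo_token_loss(-0.7, -1.0, -0.7, 1.0, 0.0,
                      ref_coef=0.0, entropy_coef=0.0, clip_higher=0.08))
```

In practice this per-token quantity is averaged over the tokens of all num_traj trajectories before each of the num_updates gradient steps. The averaging scheme (per sequence or per token) is a design choice that this sketch leaves out.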
Parameters that control the GRPO modifications:
- training_scheme_default: bool. False (default) applies both GRPO modifications: weights are synchronized every epoch rather than every step (1st modification), and sequences are sampled exclusively from the updating policy (2nd modification). Set to True to disable these modifications.
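The practical effect of the first modification is how often synchronization happens over a run. The sketch below is a hypothetical schedule, not GermRL's training loop. The README does not say which model receives the synchronized weights, so the function only enumerates the (epoch, step) points at which a sync would occur under each setting, using GRPO_epochs and steps as the loop bounds.

```python
def sync_points(num_epochs: int, steps_per_epoch: int,
                training_scheme_default: bool) -> list[tuple[int, int]]:
    """Return the (epoch, step) pairs after which weights are synchronized.

    Hypothetical sketch: the modified scheme (False) syncs once, after the
    last step of each epoch; the default GRPO scheme (True) syncs after
    every step.
    """
    points = []
    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            last_step = step == steps_per_epoch - 1
            if training_scheme_default or last_step:
                points.append((epoch, step))
    return points


print(len(sync_points(2, 5, training_scheme_default=True)))   # 10
print(len(sync_points(2, 5, training_scheme_default=False)))  # 2
```

With the modified scheme, the number of syncs scales with GRPO_epochs alone rather than with GRPO_epochs × steps.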
Parameters for the bounded LD objective:
- bound_thresh_ld: int. Upper LD threshold bound
- RGT: bool. Set to True to apply the RGT reward
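The LD thresholds are stated in terms of Levenshtein distance. The sketch below computes LD with the standard dynamic-programming recurrence and applies a hypothetical binary reward. It assumes LD is measured against an assigned germline sequence. In single-threshold mode the reward requires LD >= target_thresh_LD. Bounded mode also requires LD <= bound_thresh_ld. Both modes require pLDDT >= target_thresh_fold. The README does not describe how GermRL combines these terms or what the RGT reward does, so ld_reward is an illustrative assumption.

```python
from typing import Optional


def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) in O(len(a) * len(b))."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution (0 if match)
        prev = curr
    return prev[-1]


def ld_reward(seq: str, germline: str, plddt: float,
              target_thresh_LD: int, target_thresh_fold: float,
              bound_thresh_ld: Optional[int] = None) -> float:
    """Hypothetical binary reward: 1.0 if LD and pLDDT criteria are met, else 0.0."""
    ld = levenshtein(seq, germline)
    if ld < target_thresh_LD or plddt < target_thresh_fold:
        return 0.0
    if bound_thresh_ld is not None and ld > bound_thresh_ld:
        return 0.0  # bounded mode: too far from the germline
    return 1.0


print(levenshtein("kitten", "sitting"))  # 3
```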
Generate antibody sequences with a trained model. ProGen2-RL model weights are available on Hugging Face (e.g., "lludwig2/GermRL-LD5"). For the base model, use "hugohrban/progen2-oas".
```bash
python gen/generate.py \
    --base_path outputs/example_output_gen_viz \
    --model_weights "lludwig2/GermRL-LD5-LD15" \
    --logger_file outputs/gen_logs/generation.log \
    --temps 1.0 \
    --num_gen 20 \
    --max_len 250 \
    --top_p_val 0.9 \
    --anarci_path anarci_out.txt
```

- --base_path: str. Output directory for generated sequences and analysis
- --model_weights: str. Path to model weights (a Hugging Face model ID, as in the example above)
- --logger_file: str. Path to the log file
- --temps: float. Sampling temperatures for generation (multiple values allowed)
- --num_gen: int. Number of sequences to generate
- --max_len: int. Maximum sequence length
- --top_p_val: float. Top-p (nucleus) sampling parameter
- --anarci_path: str. Path for the intermediate file containing ANARCI results
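--temps and --top_p_val (like t and p in the training config) control temperature-scaled nucleus sampling at each decoding step. The sketch below illustrates that technique over a plain list of logits, using only the standard library. It is not the code in gen/generate.py, and the function name is hypothetical.

```python
import math
import random


def sample_top_p(logits: list[float], temperature: float, top_p: float,
                 rng: random.Random) -> int:
    """Sample a token index using temperature scaling and nucleus (top-p) filtering.

    temperature must be > 0; lower values sharpen the distribution.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)                        # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the smallest set of most likely tokens whose total mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # random.choices renormalizes the kept weights internally.
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_top_p([2.0, 1.0, 0.1, -3.0], temperature=1.0, top_p=0.9, rng=rng))
```

With these logits and top_p=0.9, the three most likely tokens cover about 99.5% of the mass, so the least likely token (index 3) can never be drawn.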
Generation creates several output files in the specified base directory:
- seq.csv: Generated antibody sequences
- LV_Fold.csv: Sequences with LD and folding scores
- v_calls.json: V gene assignments
- j_calls.json: J gene assignments
- seq_num_redo.json: Regeneration statistics for improper generations
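To see which V genes a run's sequences were assigned to, v_calls.json can be summarized as gene frequencies. The sketch below assumes the file holds either a JSON list of V gene names or a JSON object mapping sequence identifiers to V gene names. The README does not document that layout, so adjust the loading step to the file's actual structure. The function name is hypothetical.

```python
import json
from collections import Counter
from pathlib import Path


def v_gene_frequencies(path: str) -> list[tuple[str, float]]:
    """Return (V gene, fraction of sequences) pairs, most common first.

    Assumed layout: a JSON list of gene names, or an object mapping
    sequence IDs to gene names.
    """
    data = json.loads(Path(path).read_text())
    calls = list(data.values()) if isinstance(data, dict) else list(data)
    counts = Counter(calls)
    n = sum(counts.values())
    return [(gene, c / n) for gene, c in counts.most_common()]
```

For the example command above, the file would be read from the base directory, e.g. v_gene_frequencies("outputs/example_output_gen_viz/v_calls.json"). The same helper applies to j_calls.json if it uses the same layout.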