Draft
Changes from all commits
53 commits
187bd3a
Update generated protobuf artifacts
Apr 28, 2026
5edfedd
Enable Vertex AI TensorBoard for trainer jobs
Apr 28, 2026
cde7d00
test(tensorboard): add failing tests for TensorBoardWriter class
Apr 29, 2026
ddc4fcd
test(tensorboard): cover close idempotency on no-op writer
Apr 29, 2026
a4e1bca
refactor(tensorboard): replace function API with TensorBoardWriter class
Apr 29, 2026
e8037c5
refactor(v1-trainer): remove dead tensorboard writer plumbing
Apr 29, 2026
d6e1f9f
refactor(examples): migrate homogeneous_training.py to TensorBoardWriter
Apr 29, 2026
2018f23
refactor(examples): migrate heterogeneous_training.py to TensorBoardW…
Apr 29, 2026
89939de
refactor(examples): migrate graph_store/homogeneous_training.py to Te…
Apr 29, 2026
d9817f4
refactor(examples): migrate graph_store/heterogeneous_training.py to …
Apr 29, 2026
fe879dc
cleanup
Apr 29, 2026
8ff3012
chore(examples): drop shouldLogToTensorboard: true from task configs
Apr 29, 2026
f811f0a
refactor(tensorboard): move TensorBoardWriter to gigl/utils
Apr 29, 2026
e97bd64
refactor(examples): drop should_log_to_tensorboard gate
Apr 29, 2026
187123c
refactor(trainer): always forward tensorboard_logs_uri when proto is set
Apr 29, 2026
3b7133c
refactor(tensorboard): replace from_uri with from_env
Apr 30, 2026
96ea473
refactor(examples): drop tensorboard URI plumbing in homogeneous_trai…
Apr 30, 2026
23f0aae
refactor(examples): drop tensorboard URI plumbing from training examples
Apr 30, 2026
1c14a16
Update
May 1, 2026
667df9b
proto: add TrainerConfig.tensorboard_experiment_name
May 4, 2026
cd40efd
validation: tensorboard_experiment_name requires tensorboard_resource…
May 4, 2026
2a78d5a
vertex_ai: add tensorboard_experiment_name to VertexAiJobConfig
May 4, 2026
b6b70a7
vertex_ai: helper to ensure Experiment exists with backing TB
May 4, 2026
4868e04
vertex_ai: submit with experiment when tensorboard_experiment_name is…
May 4, 2026
9d3109e
launcher: thread tensorboard_experiment_name through _build_job_config
May 4, 2026
1ffd750
launcher: thread tensorboard_experiment_name through launch entrypoints
May 4, 2026
07784ed
test(launcher): tighten experiment_name negative assertions to assert…
May 4, 2026
a1d73e9
trainer: forward TrainerConfig.tensorboard_experiment_name to launcher
May 4, 2026
8dc0b49
examples: demo TrainerConfig.tensorboardExperimentName
May 4, 2026
99ab56d
validation: tensorboard_experiment_name also requires tensorboard_log…
May 4, 2026
73103ed
Revert "validation: tensorboard_experiment_name also requires tensorb…
May 4, 2026
fcc871d
vertex_ai: sanitize job_name for ExperimentRun ID, validate experimen…
May 4, 2026
be5bbf0
vertex_ai: drop experiment_run kwarg, let SDK auto-generate the run
May 4, 2026
e19f105
tensorboard: stream events from chief rank, drop submit(experiment=) …
May 4, 2026
fea1a9d
vertex_ai: always pass tensorboard= so VAI job page links to TB
May 5, 2026
31d3a35
tensorboard: emit unique run names so multi-job comparison shows two …
May 5, 2026
a5048fd
examples: scope TensorBoardWriter to a try/finally block in all train…
May 5, 2026
95981e7
tools: add dev_submit_tb_smoke_job + tb_smoke_main for fast TB iteration
May 5, 2026
51d9df7
tools: relocate smoke launcher to gigl.utils.dev (tools/ is gitignored)
May 5, 2026
f06df02
smoke: bump machine_type to n1-standard-16 (n1-standard-2 unsupported)
May 5, 2026
5b56f03
vertex_ai: log TensorBoard URLs at submit time
May 5, 2026
50ef84c
examples: drop personal experiment name from e2e CORA task config
May 5, 2026
dd88f7c
dev: remove gigl/utils/dev/ smoke tooling
May 5, 2026
e5d0cf9
docs: remove in-flight branch plan from git
May 5, 2026
36798ed
vertex_ai: collapse submit_kwargs dict back to direct kwargs
May 5, 2026
24e66be
v1: revert all v1 trainer changes — out of scope for this PR
May 5, 2026
dd55744
docs: add Vertex AI doc-link references to TB code paths
May 5, 2026
5e4d9b2
proto: move tensorboard_experiment_name to VertexAiResourceConfig
May 5, 2026
236275e
examples: update stale CORA task-config comment to point at the new p…
May 5, 2026
32a5d24
launcher: revert use_cuda + storage container_uri fixes — separate PR
May 5, 2026
b17e8d1
examples: drop try/finally around training loop, call close() at end
May 5, 2026
dfbcc04
tensorboard_writer: hoist aiplatform import to module top
May 5, 2026
5d82060
tests: drop pure-mock test files
May 5, 2026
2 changes: 1 addition & 1 deletion Makefile
@@ -260,7 +260,7 @@ run_all_e2e_tests:
# Example:
# `make compiled_pipeline_path="/tmp/gigl/my_pipeline.yaml" compile_gigl_kubeflow_pipeline`
# Can be a GCS URI as well
compile_gigl_kubeflow_pipeline: compile_jars push_new_docker_images
compile_gigl_kubeflow_pipeline: push_new_docker_images
uv run python -m gigl.orchestration.kubeflow.runner \
--action=compile \
--container_image_cuda=${DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG} \
8 changes: 8 additions & 0 deletions examples/link_prediction/README.md
@@ -23,6 +23,14 @@ are example inference and training loops for the DBLP dataset. The DBLP dataset
You can follow along with [dblp.ipynb](./dblp.ipynb) to run an e2e GiGL pipeline on the DBLP dataset. It will guide you
through running each component: `config_populator` -> `data_preprocessor` -> `trainer` -> `inferencer`

## Vertex AI TensorBoard

The example trainers write metrics through `TensorBoardWriter.from_env`; no task-config flag is needed to enable logging.

To surface those events in Vertex AI TensorBoard, set `tensorboard_resource_name` on the trainer Vertex resource config,
use a regional bucket, and keep the bucket, CustomJob, and TensorBoard instance in the same region. The attached service
account should have `roles/storage.admin` and `roles/aiplatform.user`.
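
A minimal sketch of the writer flow in a training process; the method names (`from_env`, `log`, `close`) match how this PR's examples use `TensorBoardWriter`, while the rank, metric value, and step are placeholders:

```python
from gigl.utils.tensorboard_writer import TensorBoardWriter

rank = 0  # placeholder; the examples derive this from torch.distributed
# Only the chief rank emits events; a writer created with enabled=False is a no-op.
writer = TensorBoardWriter.from_env(enabled=(rank == 0))
writer.log({"Loss/train": 0.42}, step=100)  # maps scalar tags to values at a step
writer.close()  # safe to call once at the end, even on a no-op writer
```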

```{toctree}
:maxdepth: 2
:hidden:

@@ -18,6 +18,10 @@ trainerConfig:
log_every_n_batch: "50" # Frequency in which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
command: python -m examples.link_prediction.homogeneous_training
# To enable cross-job TensorBoard comparison, set
# ``GiglResourceConfig.trainerResourceConfig...tensorboardExperimentName``
# alongside the ``tensorboardResourceName`` on the same resource config.
# See ``proto/snapchat/research/gbml/gigl_resource_config.proto`` for details.
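# Hypothetical sketch of the pairing on the resource config; the nesting under
# ``vertex_ai_trainer_config`` is inferred from the sibling inferencer config in
# this PR, and the project/ID values are placeholders:
#   trainer_resource_config:
#     vertex_ai_trainer_config:
#       tensorboard_resource_name: "projects/MY_PROJECT/locations/us-central1/tensorboards/MY_TB_ID"
#       tensorboard_experiment_name: "shared-experiment"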
inferencerConfig:
inferencerArgs:
# Example argument to inferencer

@@ -43,6 +43,7 @@ trainer_resource_config:
gpu_type: NVIDIA_TESLA_T4
gpu_limit: 2
num_replicas: 2
tensorboard_resource_name: "projects/USER_PROVIDED_PROJECT/locations/us-central1/tensorboards/USER_PROVIDED_TENSORBOARD_ID"
inferencer_resource_config:
vertex_ai_inferencer_config:
machine_type: n1-standard-16

@@ -58,6 +58,7 @@ trainer_resource_config:
gpu_type: NVIDIA_TESLA_T4
gpu_limit: 2
num_replicas: 2
tensorboard_resource_name: "projects/USER_PROVIDED_PROJECT/locations/us-central1/tensorboards/USER_PROVIDED_TENSORBOARD_ID"
inferencer_resource_config:
vertex_ai_graph_store_inferencer_config:
graph_store_pool:
26 changes: 22 additions & 4 deletions examples/link_prediction/graph_store/heterogeneous_training.py
@@ -115,6 +115,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -459,12 +460,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
print(f"---Rank {rank} training process set device {device}")
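# Only the chief rank emits TensorBoard events; a writer created with enabled=False is a no-op.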
is_chief_process = rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
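# Initialized before the optional training block so the test-metric log below always has a valid step.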
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -525,7 +529,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -567,25 +570,35 @@ def _training_process(
if (
batch_idx % args.log_every_n_batch == 0 or batch_idx < 10
): # Log the first 10 batches to ensure the model is initialized correctly
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
print(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
torch.cuda.synchronize()
print(
f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, batch={batch_idx}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
print(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()
flush()

if batch_idx % args.val_every_n_batch == 0:
print(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -596,6 +609,9 @@
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()
else:
print(f"rank={rank} ended training early - no break condition was met")
@@ -674,6 +690,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -701,6 +718,7 @@ def _training_process(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
flush()
tensorboard_writer.close()

# Graph store mode cleanup: shutdown the compute process connection to the storage cluster.
shutdown_compute_proccess()
26 changes: 22 additions & 4 deletions examples/link_prediction/graph_store/homogeneous_training.py
@@ -159,6 +159,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -450,12 +451,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
logger.info(f"---Rank {rank} training process set device {device}")
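# Only the chief rank emits TensorBoard events; a writer created with enabled=False is a no-op.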
is_chief_process = rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
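# Initialized before the optional training block so the test-metric log below always has a valid step.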
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -517,7 +521,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -555,25 +558,35 @@ def _training_process(
batch_start = time.time()
batch_idx += 1
if batch_idx % args.log_every_n_batch == 0:
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
logger.info(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
torch.cuda.synchronize()
logger.info(
f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
logger.info(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()
flush()

if batch_idx % args.val_every_n_batch == 0:
logger.info(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -582,6 +595,9 @@
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()

logger.info(f"---Rank {rank} finished training")
@@ -657,6 +673,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -684,6 +701,7 @@ def _training_process(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
flush()
tensorboard_writer.close()

# Graph store mode cleanup: shutdown the compute process connection to the storage cluster.
shutdown_compute_proccess()
26 changes: 22 additions & 4 deletions examples/link_prediction/heterogeneous_training.py
@@ -65,6 +65,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -400,11 +401,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
logger.info(f"---Rank {rank} training process set device {device}")
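# The chief is the rank-0 local process on the first machine; only it emits TensorBoard events.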
is_chief_process = args.machine_rank == 0 and local_rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
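# Initialized before the optional training block so the test-metric log below always has a valid step.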
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -469,7 +474,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -509,26 +513,35 @@ def _training_process(
batch_start = time.time()
batch_idx += 1
if batch_idx % args.log_every_n_batch == 0:
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
logger.info(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
# Wait for GPU operations to finish
torch.cuda.synchronize()
logger.info(
f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, batch={batch_idx}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
logger.info(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()

if batch_idx % args.val_every_n_batch == 0:
logger.info(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -538,6 +551,9 @@
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()

logger.info(f"---Rank {rank} finished training")
@@ -619,6 +635,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -648,6 +665,7 @@ def _training_process(
logger.info(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
tensorboard_writer.close()

torch.distributed.destroy_process_group()
