8 changes: 8 additions & 0 deletions examples/link_prediction/README.md
@@ -23,6 +23,14 @@ are example inference and training loops for the DBLP dataset. The DBLP dataset
You can follow along with [dblp.ipynb](./dblp.ipynb) to run an e2e GiGL pipeline on the DBLP dataset. It will guide you
through running each component: `config_populator` -> `data_preprocessor` -> `trainer` -> `inferencer`

## Vertex AI TensorBoard

The example trainer configs enable TensorBoard logging with `trainerConfig.shouldLogToTensorboard: true`.
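
A minimal sketch of that flag in a trainer config, following the field path above:

```yaml
trainerConfig:
  shouldLogToTensorboard: true
```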

To surface those events in Vertex AI TensorBoard, set `tensorboard_resource_name` on the trainer's Vertex AI resource
config, use a regional GCS bucket, and keep the bucket, the CustomJob, and the TensorBoard instance in the same region.
The attached service account needs `roles/storage.admin` and `roles/aiplatform.user`.
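
For reference, a sketch of the resource config additions from this PR (the project, TensorBoard ID, and experiment
name are user-supplied placeholders; sibling fields such as `gpu_type` are elided):

```yaml
trainer_resource_config:
  # ... machine shape fields elided ...
  tensorboard_resource_name: "projects/USER_PROVIDED_PROJECT/locations/us-central1/tensorboards/USER_PROVIDED_TENSORBOARD_ID"
  tensorboard_experiment_name: "user-provided-experiment-name"
```

On the training side, the trainers use `gigl.utils.tensorboard_writer.TensorBoardWriter`. A minimal usage sketch based
on the calls in this PR; the assumption that `from_env` resolves its log directory from the environment Vertex AI
provides, and that a disabled writer no-ops, is ours and not confirmed by this diff:

```python
import torch.distributed as dist

from gigl.utils.tensorboard_writer import TensorBoardWriter

# Only the chief process writes events; on other ranks the writer is created
# disabled, and its log()/close() calls are assumed to be no-ops.
rank = dist.get_rank() if dist.is_initialized() else 0
writer = TensorBoardWriter.from_env(enabled=(rank == 0))

# Scalars are logged as a dict of tag -> value at a given global step.
writer.log({"Loss/train": 0.42, "Time/batch_mean_sec": 1.3}, step=100)
writer.close()
```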

```{toctree}
:maxdepth: 2
:hidden:
@@ -18,6 +18,10 @@ trainerConfig:
log_every_n_batch: "50" # Frequency at which we log batch information
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
command: python -m examples.link_prediction.homogeneous_training
# To enable cross-job TensorBoard comparison, set
# ``GiglResourceConfig.trainerResourceConfig...tensorboardExperimentName``
# alongside the ``tensorboardResourceName`` on the same resource config.
# See ``proto/snapchat/research/gbml/gigl_resource_config.proto`` for details.
inferencerConfig:
inferencerArgs:
# Example argument to inferencer
@@ -43,6 +43,8 @@ trainer_resource_config:
gpu_type: NVIDIA_TESLA_T4
gpu_limit: 2
num_replicas: 2
tensorboard_resource_name: "projects/USER_PROVIDED_PROJECT/locations/us-central1/tensorboards/USER_PROVIDED_TENSORBOARD_ID"
tensorboard_experiment_name: "user-provided-experiment-name"
inferencer_resource_config:
vertex_ai_inferencer_config:
machine_type: n1-standard-16
@@ -58,6 +58,8 @@ trainer_resource_config:
gpu_type: NVIDIA_TESLA_T4
gpu_limit: 2
num_replicas: 2
tensorboard_resource_name: "projects/USER_PROVIDED_PROJECT/locations/us-central1/tensorboards/USER_PROVIDED_TENSORBOARD_ID"
tensorboard_experiment_name: "user-provided-experiment-name"
inferencer_resource_config:
vertex_ai_graph_store_inferencer_config:
graph_store_pool:
26 changes: 22 additions & 4 deletions examples/link_prediction/graph_store/heterogeneous_training.py
@@ -115,6 +115,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -459,12 +460,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
print(f"---Rank {rank} training process set device {device}")
is_chief_process = rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -525,7 +529,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -567,25 +570,35 @@ def _training_process(
if (
batch_idx % args.log_every_n_batch == 0 or batch_idx < 10
): # Log the first 10 batches to ensure the model is initialized correctly
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
print(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
torch.cuda.synchronize()
print(
f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, batch={batch_idx}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
print(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()
flush()

if batch_idx % args.val_every_n_batch == 0:
print(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -596,6 +609,9 @@ def _training_process(
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()
else:
print(f"rank={rank} ended training early - no break condition was met")
@@ -674,6 +690,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -701,6 +718,7 @@ def _training_process(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
flush()
tensorboard_writer.close()

# Graph store mode cleanup: shutdown the compute process connection to the storage cluster.
shutdown_compute_process()
28 changes: 23 additions & 5 deletions examples/link_prediction/graph_store/homogeneous_training.py
@@ -50,7 +50,7 @@
| **Inter-process sharing** | N/A (each process loads own partition) | Each process fetches its own shard from |
| | | the storage cluster |
+---------------------------+----------------------------------------------+----------------------------------------------+
| **Cleanup** | ``torch.distributed.destroy_process_group()`` | ``shutdown_compute_process()`` disconnects |
| **Cleanup** | ``torch.distributed.destroy_process_group()`` | ``shutdown_compute_process()`` disconnects |
| | | from storage cluster |
+---------------------------+----------------------------------------------+----------------------------------------------+

@@ -159,6 +159,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -450,12 +451,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
logger.info(f"---Rank {rank} training process set device {device}")
is_chief_process = rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -517,7 +521,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -555,25 +558,35 @@ def _training_process(
batch_start = time.time()
batch_idx += 1
if batch_idx % args.log_every_n_batch == 0:
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
logger.info(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
torch.cuda.synchronize()
logger.info(
f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
logger.info(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()
flush()

if batch_idx % args.val_every_n_batch == 0:
logger.info(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -582,6 +595,9 @@ def _training_process(
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()

logger.info(f"---Rank {rank} finished training")
@@ -657,6 +673,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -684,6 +701,7 @@ def _training_process(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
flush()
tensorboard_writer.close()

# Graph store mode cleanup: shutdown the compute process connection to the storage cluster.
shutdown_compute_process()
26 changes: 22 additions & 4 deletions examples/link_prediction/heterogeneous_training.py
@@ -65,6 +65,7 @@
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.tensorboard_writer import TensorBoardWriter

logger = Logger()

@@ -400,11 +401,15 @@ def _training_process(
if torch.cuda.is_available():
torch.cuda.set_device(device)
logger.info(f"---Rank {rank} training process set device {device}")
is_chief_process = args.machine_rank == 0 and local_rank == 0
tensorboard_writer = TensorBoardWriter.from_env(enabled=is_chief_process)

loss_fn = RetrievalLoss(
loss=torch.nn.CrossEntropyLoss(reduction="mean"),
temperature=0.07,
remove_accidental_hits=True,
)
batch_idx = 0

if not args.should_skip_training:
train_main_loader, train_random_negative_loader = _setup_dataloaders(
@@ -469,7 +474,6 @@ def _training_process(

# Entering the training loop
training_start_time = time.time()
batch_idx = 0
avg_train_loss = 0.0
last_n_batch_avg_loss: list[float] = []
last_n_batch_time: list[float] = []
@@ -509,26 +513,35 @@ def _training_process(
batch_start = time.time()
batch_idx += 1
if batch_idx % args.log_every_n_batch == 0:
mean_batch_time = statistics.mean(last_n_batch_time)
mean_train_loss = statistics.mean(last_n_batch_avg_loss)
logger.info(
f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}"
)
if torch.cuda.is_available():
# Wait for GPU operations to finish
torch.cuda.synchronize()
logger.info(
f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
f"rank={rank}, batch={batch_idx}, mean(batch_time)={mean_batch_time:.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec"
)
tensorboard_writer.log(
{
"Time/batch_mean_sec": mean_batch_time,
"Loss/train": mean_train_loss,
},
step=batch_idx,
)
last_n_batch_time.clear()
# log the global average training loss
logger.info(
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}"
f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={mean_train_loss:.6f}"
)
last_n_batch_avg_loss.clear()

if batch_idx % args.val_every_n_batch == 0:
logger.info(f"rank={rank}, batch={batch_idx}, validating...")
model.eval()
_run_validation_loops(
global_avg_val_loss = _run_validation_loops(
model=model,
main_loader=val_main_loader_iter,
random_negative_loader=val_random_negative_loader_iter,
@@ -538,6 +551,9 @@ def _training_process(
log_every_n_batch=args.log_every_n_batch,
num_batches=num_val_batches_per_process,
)
tensorboard_writer.log(
{"Loss/val": global_avg_val_loss}, step=batch_idx
)
model.train()

logger.info(f"---Rank {rank} finished training")
@@ -619,6 +635,7 @@ def _training_process(
device=device,
log_every_n_batch=args.log_every_n_batch,
)
tensorboard_writer.log({"Loss/test": global_avg_test_loss}, step=batch_idx)

# Memory cleanup and waiting for all processes to finish
if torch.cuda.is_available():
@@ -648,6 +665,7 @@ def _training_process(
logger.info(
f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds"
)
tensorboard_writer.close()

torch.distributed.destroy_process_group()
