Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions gigl/src/common/types/pb_wrappers/gigl_resource_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gigl.common.logger import Logger
from gigl.src.common.constants.components import GiGLComponents
from snapchat.research.gbml.gigl_resource_config_pb2 import (
CustomResourceConfig,
DataflowResourceConfig,
DataPreprocessorConfig,
DistributedTrainerConfig,
Expand Down Expand Up @@ -37,12 +38,14 @@
_KFP_TRAINER_CONFIG = "kfp_trainer_config"
_LOCAL_TRAINER_CONFIG = "local_trainer_config"
_VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG = "vertex_ai_graph_store_trainer_config"
_CUSTOM_TRAINER_CONFIG = "custom_trainer_config"

_INFERENCER_CONFIG_FIELD = "inferencer_config"
_VERTEX_AI_INFERENCER_CONFIG = "vertex_ai_inferencer_config"
_DATAFLOW_INFERENCER_CONFIG = "dataflow_inferencer_config"
_LOCAL_INFERENCER_CONFIG = "local_inferencer_config"
_VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG = "vertex_ai_graph_store_inferencer_config"
_CUSTOM_INFERENCER_CONFIG = "custom_inferencer_config"


@dataclass
Expand All @@ -55,6 +58,7 @@ class GiglResourceConfigWrapper:
KFPResourceConfig,
LocalResourceConfig,
VertexAiGraphStoreConfig,
CustomResourceConfig,
]
] = None
_inference_config: Optional[
Expand All @@ -63,6 +67,7 @@ class GiglResourceConfigWrapper:
VertexAiResourceConfig,
LocalResourceConfig,
VertexAiGraphStoreConfig,
CustomResourceConfig,
]
] = None

Expand Down Expand Up @@ -283,9 +288,10 @@ def trainer_config(
KFPResourceConfig,
LocalResourceConfig,
VertexAiGraphStoreConfig,
CustomResourceConfig,
]:
"""
Returns the trainer config specified in the resource config. (e.g. Vertex AI, KFP, Local)
Returns the trainer config specified in the resource config. (e.g. Vertex AI, KFP, Local, Custom)
"""

if not self._trainer_config:
Expand All @@ -305,6 +311,7 @@ def trainer_config(
KFPResourceConfig,
LocalResourceConfig,
VertexAiGraphStoreConfig,
CustomResourceConfig,
]
if (
deprecated_config.WhichOneof(_TRAINER_CONFIG_FIELD) # type: ignore[arg-type]
Expand Down Expand Up @@ -365,6 +372,11 @@ def trainer_config(
== _VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG
):
_trainer_config = config.vertex_ai_graph_store_trainer_config
elif (
config.WhichOneof(_TRAINER_CONFIG_FIELD) # type: ignore[arg-type]
== _CUSTOM_TRAINER_CONFIG
):
_trainer_config = config.custom_trainer_config
else:
raise ValueError(f"Invalid trainer_config type: {config}")
else:
Expand All @@ -383,9 +395,10 @@ def inferencer_config(
VertexAiResourceConfig,
LocalResourceConfig,
VertexAiGraphStoreConfig,
CustomResourceConfig,
]:
"""
Returns the inferencer config specified in the resource config. (Dataflow)
Returns the inferencer config specified in the resource config. (e.g. Dataflow, Vertex AI, Local, Custom)
"""
if self._inference_config is None:
# TODO: (svij) Marked for deprecation
Expand Down Expand Up @@ -421,6 +434,11 @@ def inferencer_config(
self._inference_config = (
config.vertex_ai_graph_store_inferencer_config
)
elif (
config.WhichOneof(_INFERENCER_CONFIG_FIELD) # type: ignore[arg-type]
== _CUSTOM_INFERENCER_CONFIG
):
self._inference_config = config.custom_inferencer_config
else:
raise ValueError("Invalid inferencer_config type")
else:
Expand Down
15 changes: 15 additions & 0 deletions gigl/src/validation_check/libs/resource_config_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,15 @@ def check_if_trainer_resource_config_valid(
gigl_resource_config_pb2.VertexAiResourceConfig,
gigl_resource_config_pb2.KFPResourceConfig,
gigl_resource_config_pb2.VertexAiGraphStoreConfig,
gigl_resource_config_pb2.CustomResourceConfig,
] = wrapper.trainer_config
if isinstance(trainer_config, gigl_resource_config_pb2.CustomResourceConfig):
logger.info(
"Skipping trainer machine-shape validation: trainer_config is a "
"CustomResourceConfig (launcher-pluggable; no concrete machine "
"spec to validate)."
)
return
_validate_machine_config(config=trainer_config)


Expand All @@ -163,6 +171,13 @@ def check_if_inferencer_resource_config_valid(
resource_config=resource_config_pb
)
inferencer_config = resource_config_wrapper.inferencer_config
if isinstance(inferencer_config, gigl_resource_config_pb2.CustomResourceConfig):
logger.info(
"Skipping inferencer machine-shape validation: inferencer_config "
"is a CustomResourceConfig (launcher-pluggable; no concrete "
"machine spec to validate)."
)
return
_validate_machine_config(config=inferencer_config)


Expand Down
21 changes: 21 additions & 0 deletions proto/snapchat/research/gbml/gigl_resource_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ message VertexAiResourceConfig {
// If unset, and no accelerators are available, will use 1.
int32 compute_cluster_local_world_size = 3;
}

// Resource config that lets users plug in their own (custom) launcher.
// The launcher dispatcher invokes `command` (interpreted by /bin/sh -c, so
// leading "KEY=VALUE" assignments parse as inline env vars) with `args`
// appended as positional arguments. The dispatcher performs no templating
// or substitution on either field; that is the caller's responsibility
// (e.g. OmegaConf-resolved at YAML-load time).
message CustomResourceConfig {
// Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments
// are honored by the shell, so callers can inline env vars (e.g.
// "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli").
string command = 1;
// Positional arguments appended after the command. Each element is
// shell-quoted by the dispatcher so values containing spaces/quotes
// survive the shell pass.
// e.g. ["--my_flag=my_value", "--my_other_flag=my_other_value", "--noskip_training"]
repeated string args = 2;
}

// (deprecated)
// Configuration for distributed training resources
message DistributedTrainerConfig {
Expand All @@ -183,6 +202,7 @@ message TrainerResourceConfig {
KFPResourceConfig kfp_trainer_config = 2;
LocalResourceConfig local_trainer_config = 3;
VertexAiGraphStoreConfig vertex_ai_graph_store_trainer_config = 4;
CustomResourceConfig custom_trainer_config = 5;
}
}

Expand All @@ -193,6 +213,7 @@ message InferencerResourceConfig {
DataflowResourceConfig dataflow_inferencer_config = 2;
LocalResourceConfig local_inferencer_config = 3;
VertexAiGraphStoreConfig vertex_ai_graph_store_inferencer_config = 4;
CustomResourceConfig custom_inferencer_config = 5;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Generated by the Scala Plugin for the Protocol Buffer Compiler.
// Do not edit!
//
// Protofile syntax: PROTO3

package snapchat.research.gbml.gigl_resource_config

/** Lets user-defined launchers be piped in.
* The launcher dispatcher invokes `command` (interpreted by /bin/sh -c so
* leading "KEY=VALUE" assignments parse as inline env vars) with `args`
* appended as positional arguments. Both fields are taken verbatim by
* the dispatcher; any templating or substitution is the caller's
* responsibility (e.g. OmegaConf-resolved at YAML-load time).
*
* @param command
* Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments
* are honored by the shell, so callers can inline env vars (e.g.
* "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli").
* @param args
* Positional arguments appended after the command. Each element is
* shell-quoted by the dispatcher so values containing spaces/quotes
* survive the shell pass.
* e.g. "[--my_flag=my_value, --my_other_flag=my_other_value, --noskip_training]"
*/
@SerialVersionUID(0L)
final case class CustomResourceConfig(
command: _root_.scala.Predef.String = "",
args: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty,
unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty
) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[CustomResourceConfig] {
// Cache for the serialized size, stored as (size + 1) so that the initial
// value 0 means "not computed yet" (see serializedSize below).
@transient
private[this] var __serializedSizeMemoized: _root_.scala.Int = 0
// Adds up the encoded size of `command` (skipped when empty), of every
// element of `args`, and of the unknown fields.
private[this] def __computeSerializedSize(): _root_.scala.Int = {
var __size = 0

{
val __value = command
if (!__value.isEmpty) {
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value)
}
};
args.foreach { __item =>
val __value = __item
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value)
}
__size += unknownFields.serializedSize
__size
}
// Returns the memoized size, computing it on first use. The +1 on store and
// -1 on return let a genuine size of 0 be told apart from "not computed".
override def serializedSize: _root_.scala.Int = {
var __size = __serializedSizeMemoized
if (__size == 0) {
__size = __computeSerializedSize() + 1
__serializedSizeMemoized = __size
}
__size - 1

}
// Writes `command` as field 1 (omitted when empty), each element of `args`
// as field 2, and finally the unknown fields.
def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = {
{
val __v = command
if (!__v.isEmpty) {
_output__.writeString(1, __v)
}
};
args.foreach { __v =>
val __m = __v
_output__.writeString(2, __m)
};
unknownFields.writeTo(_output__)
}
// Copy-based "setters": each returns a new instance and leaves this one untouched.
def withCommand(__v: _root_.scala.Predef.String): CustomResourceConfig = copy(command = __v)
def clearArgs = copy(args = _root_.scala.Seq.empty)
def addArgs(__vs: _root_.scala.Predef.String *): CustomResourceConfig = addAllArgs(__vs)
def addAllArgs(__vs: Iterable[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = args ++ __vs)
def withArgs(__v: _root_.scala.Seq[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = __v)
def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v)
def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty)
// Returns the raw value of field 1 (`command`) or field 2 (`args`).
// An empty `command` is reported as null rather than "".
def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = {
(__fieldNumber: @_root_.scala.unchecked) match {
case 1 => {
val __t = command
if (__t != "") __t else null
}
case 2 => args
}
}
// Returns the field as a PValue; the descriptor must belong to this
// message's scalaDescriptor or `require` fails.
def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = {
_root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor)
(__field.number: @_root_.scala.unchecked) match {
case 1 => _root_.scalapb.descriptors.PString(command)
case 2 => _root_.scalapb.descriptors.PRepeated(args.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector)
}
}
def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this)
def companion: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.type = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig
// @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.CustomResourceConfig])
}

// Companion object: parsing, descriptors, default instance, lenses and
// field-number constants for CustomResourceConfig.
object CustomResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = this
// Reads tags from the stream until tag 0. Tag 10 sets `command`, tag 18
// appends to `args` (fields 1 and 2 with the length-delimited wire type);
// every other tag is collected into the unknown-field set.
def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = {
var __command: _root_.scala.Predef.String = ""
val __args: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String]
// Stays null until the first unknown tag is seen, so the builder is only
// allocated when needed.
var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null
var _done__ = false
while (!_done__) {
val _tag__ = _input__.readTag()
_tag__ match {
case 0 => _done__ = true
case 10 =>
__command = _input__.readStringRequireUtf8()
case 18 =>
__args += _input__.readStringRequireUtf8()
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
}
_unknownFields__.parseField(tag, _input__)
}
}
snapchat.research.gbml.gigl_resource_config.CustomResourceConfig(
command = __command,
args = __args.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
// Builds a message from a PMessage field map; a missing field falls back to
// "" / an empty Seq. Any other PValue throws RuntimeException.
implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
_root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.")
snapchat.research.gbml.gigl_resource_config.CustomResourceConfig(
command = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""),
args = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty)
)
case _ => throw new RuntimeException("Expected PMessage")
}
// Index 11 selects this message in GiglResourceConfigProto's message list.
// NOTE(review): presumably this tracks declaration order in the .proto file,
// so regenerate instead of editing the index by hand - confirm.
def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11)
def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11)
// Always throws MatchError: no field number maps to a nested message companion here.
def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number)
lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty
// Always throws MatchError: no field number maps to an enum companion here.
def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber)
lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig(
command = "",
args = _root_.scala.Seq.empty
)
// Lenses over `command` and `args`, built from the fields' getters and copy-based setters.
implicit class CustomResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_l) {
def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.command)((c_, f_) => c_.copy(command = f_))
def args: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.args)((c_, f_) => c_.copy(args = f_))
}
final val COMMAND_FIELD_NUMBER = 1
final val ARGS_FIELD_NUMBER = 2
// Positional factory that takes only the declared fields; unknownFields keeps its default.
def of(
command: _root_.scala.Predef.String,
args: _root_.scala.Seq[_root_.scala.Predef.String]
): _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig(
command,
args
)
// @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.CustomResourceConfig])
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch
)
case _ => throw new RuntimeException("Expected PMessage")
}
def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11)
def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11)
def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12)
def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12)
def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = {
var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null
(__number: @_root_.scala.unchecked) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res
)
case _ => throw new RuntimeException("Expected PMessage")
}
def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(15)
def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(15)
def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(16)
def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(16)
def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = {
var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null
(__number: @_root_.scala.unchecked) match {
Expand Down
Loading