Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions swift/arguments/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
from .profile_args import ProfilerArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments

Expand All @@ -30,12 +31,11 @@ def get_supported_tuners():

@dataclass
class BaseArguments(GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments,
RayArguments):
RayArguments, ProfilerArguments):
"""BaseArguments class is a dataclass that inherits from multiple argument classes.

This class consolidates arguments from GenerationArguments, QuantizeArguments, DataArguments,
TemplateArguments, ModelArguments, RayArguments.

TemplateArguments, ModelArguments, RayArguments, and ProfilerArguments.
Args:
tuner_backend (str): The tuner backend to use. Choices are 'peft' or 'unsloth'. Default is 'peft'.
tuner_type (str): The tuner type. Choices include 'lora', 'full', 'longlora', 'adalora', 'llamapro',
Expand Down Expand Up @@ -171,6 +171,7 @@ def __post_init__(self):
TemplateArguments.__post_init__(self)
DataArguments.__post_init__(self)
RayArguments.__post_init__(self)
ProfilerArguments.__post_init__(self)
self._init_stream()
if self.max_length is None and self.model_info is not None:
self.max_length = self.model_info.max_model_len
Expand Down
55 changes: 55 additions & 0 deletions swift/arguments/base_args/profile_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from dataclasses import dataclass, field
from typing import List, Optional

from swift.utils import get_logger

logger = get_logger()


@dataclass
class ProfilerArguments:

enable_profiler: bool = False
profiler_save_path: Optional[str] = None
profiler_all_ranks: bool = False
profiler_ranks: List[int] = field(default_factory=list)
profiler_contents: List[str] = field(default_factory=list) # e.g., "cpu", "cuda", "stack", "memory"."shape"
profiler_discrete: bool = False
profiler_tool: Optional[str] = 'torch'
profiler_steps: Optional[List[int]] = field(default_factory=list) # Steps to profile

def __post_init__(self):
assert not self.profiler_discrete, \
'Profiler discrete mode is not supported yet, please set profiler_discrete to false'
if self.enable_profiler and 'profiler' not in self.callbacks:
self.callbacks.append('profiler')
if 'profiler' in self.callbacks and not self.enable_profiler:
self.enable_profiler = True
if self.enable_profiler:
assert self.profiler_save_path is not None, \
'Profiler save path must be specified when profiler is enabled.'
assert self.profiler_contents, \
'Profiler contents must be specified when profiler is enabled.'
assert self.profiler_steps, \
'Profiler steps must be specified when profiler is enabled.'
assert self.profiler_ranks != [] or self.profiler_all_ranks, \
'Either profiler_ranks must be specified or profiler_all_ranks must be set to True.'
Comment thread
qq1243196045 marked this conversation as resolved.
if self.enable_profiler:
assert 'profiler' in self.callbacks, \
'Profiler callback must be included in callbacks when profiler is enabled.'
if 'profiler' in self.callbacks:
assert self.enable_profiler, \
'Profiler callback is included in callbacks but profiler is not enabled.'
Comment thread
qq1243196045 marked this conversation as resolved.
Outdated

def get_profiler_kwargs(self):
return {
'enable_profiler': self.enable_profiler,
'profiler_save_path': self.profiler_save_path,
'profiler_all_ranks': self.profiler_all_ranks,
'profiler_ranks': self.profiler_ranks,
'profiler_contents': self.profiler_contents,
'profiler_discrete': self.profiler_discrete,
'profiler_tool': self.profiler_tool,
'profiler_steps': self.profiler_steps,
}
4 changes: 3 additions & 1 deletion swift/callbacks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .early_stop import EarlyStopCallback
from .lisa import LISACallback
from .perf_log import PerfMetricsLogCallback
from .profiler import ProfilerCallback

callbacks_map = {
'activation_cpu_offload': ActivationCpuOffloadCallBack,
Expand All @@ -13,5 +14,6 @@
'early_stop': EarlyStopCallback,
'graceful_exit': GracefulExitCallback,
'lisa': LISACallback,
'perf_log': PerfMetricsLogCallback
'perf_log': PerfMetricsLogCallback,
'profiler': ProfilerCallback,
}
26 changes: 26 additions & 0 deletions swift/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from transformers.trainer_callback import ProgressCallback, TrainerControl, TrainerState

from swift.utils import get_logger
from swift.utils.profiler import DistProfiler

logger = get_logger()


class ProfilerCallback(ProgressCallback):

def __init__(self, args, trainer):
super().__init__()
self.args = args
self.trainer = trainer
self.trainer.profiler = DistProfiler(global_config=args)

def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
if self.args.profiler_steps and state.global_step in self.args.profiler_steps:
self.trainer.profiler.start()
super().on_step_begin(args, state, control, **kwargs)

def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
if self.args.profiler_steps and state.global_step + 1 not in self.args.profiler_steps:
self.trainer.profiler.stop()
super().on_step_end(args, state, control, **kwargs)
2 changes: 2 additions & 0 deletions swift/megatron/callbacks/mapping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .default_flow import DefaultFlowCallback
from .print import PrintCallback
from .profiler import ProfilerCallback
from .swanlab import SwanlabCallback
from .tensorboard import TensorboardCallback
from .wandb import WandbCallback
Expand All @@ -11,4 +12,5 @@
'swanlab': SwanlabCallback,
'wandb': WandbCallback,
'tensorboard': TensorboardCallback,
'profiler': ProfilerCallback,
}
23 changes: 23 additions & 0 deletions swift/megatron/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from swift.utils import get_logger
from swift.utils.profiler import DistProfiler
from .base import MegatronCallback

logger = get_logger()


class ProfilerCallback(MegatronCallback):

def __init__(self, trainer):
super().__init__(trainer)
self.args = trainer.args
self.trainer = trainer
self.trainer.profiler = DistProfiler(global_config=self.args)

def on_step_begin(self):
if self.args.profiler_steps and self.state.global_step in self.args.profiler_steps:
self.trainer.profiler.start()

def on_step_end(self):
if self.args.profiler_steps and self.state.global_step + 1 not in self.args.profiler_steps:
self.trainer.profiler.stop()
20 changes: 20 additions & 0 deletions swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ class TrainArgumentsMixin:
shared memory and then asynchronously persisted to disk. Currently does not support the safetensors format.
It is recommended to use this with `PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"` to prevent CUDA OOM
errors during training. Defaults to False.

enable_profiler (bool): Master switch to enable or disable performance profiling. Default is False.
profiler_save_path (Optional[str]): Directory path where the profiling results and trace files will be saved.
profiler_all_ranks (bool): If True, collects profiling data from all distributed processes.
profiler_ranks (List[int]): A list of specific rank IDs to collect profiling data from.
profiler_contents (List[str]): List of data categories to record, such as "cpu", "cuda", "memory", or "stack".
profiler_discrete (bool): If True, records data for each step independently.
profiler_tool (Optional[str]): Specifies the backend tool used for profiling.
profiler_steps (Optional[List[int]]): A list of specific training steps during which profiling should be active.

"""
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
Expand Down Expand Up @@ -202,6 +212,16 @@ class TrainArgumentsMixin:
# dlrover flash_checkpoint
use_flash_ckpt: bool = False

# profiler
enable_profiler: bool = False
profiler_save_path: Optional[str] = None
profiler_all_ranks: bool = False
profiler_ranks: List[int] = field(default_factory=list)
profiler_contents: List[str] = field(default_factory=list) # e.g., "cpu", "cuda", "stack", "memory"."shape"
profiler_discrete: bool = False
profiler_tool: Optional[str] = None
Comment thread
qq1243196045 marked this conversation as resolved.
Outdated
profiler_steps: Optional[List[int]] = field(default_factory=list) # Steps to profile

@staticmethod
def _patch_liger_kernel():
# fix logits_to_keep
Expand Down
7 changes: 7 additions & 0 deletions swift/utils/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig

__all__ = [
'DistProfiler',
'DistProfilerExtension',
'ProfilerConfig',
]
142 changes: 142 additions & 0 deletions swift/utils/profiler/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import collections
from dataclasses import FrozenInstanceError, dataclass, field, fields
from typing import Any, Optional

# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary


@dataclass
class BaseConfig(collections.abc.Mapping):
"""The BaseConfig provides dict-like interface for a dataclass config.

By default all fields in the config is not mutable, unless specified in
"_mutable_fields". The BaseConfig class implements the Mapping Abstract Base Class.
This allows instances of this class to be used like dictionaries.
"""

_mutable_fields = set()
_target_: str = ''

def __setattr__(self, name: str, value):
"""Set the value of an attribute. Check if the attr is mutable before setting the value."""
# If the field already exists, it's considered frozen unless it's in _mutable_fields
if name in self.__dict__ and name not in getattr(self, '_mutable_fields', set()):
raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
super().__setattr__(name, value)

def get(self, key: str, default: Any = None) -> Any:
"""Get the value associated with the given key. If the key does not exist, return the default value.

Args:
key (str): The attribute name to retrieve.
default (Any, optional): The value to return if the attribute does not exist. Defaults to None.

Returns:
Any: The value of the attribute or the default value.
"""
try:
return getattr(self, key)
except AttributeError:
return default

def __getitem__(self, key: str):
"""Implement the [] operator for the class. Allows accessing attributes like dictionary items.

Args:
key (str): The attribute name to retrieve.

Returns:
Any: The value of the attribute.

Raises:
AttributeError: If the attribute does not exist.
TypeError: If the key type is not string
"""
return getattr(self, key)

def __iter__(self):
"""Implement the iterator protocol. Allows iterating over the attribute names of the instance.

Yields:
str: The name of each field in the dataclass.
"""
for f in fields(self):
yield f.name

def __len__(self):
"""
Return the number of fields in the dataclass.

Returns:
int: The number of fields in the dataclass.
"""
return len(fields(self))


@dataclass
class ProfilerConfig(BaseConfig):
"""Worker profiler config.

Args:
discrete (bool): True for each task has its own database, False for all tasks in one training step
share one database.
all_ranks (bool): Whether to profile all ranks.
ranks (list[int]): The ranks that will be profiled. Defaults to [].
global_tool_config (Any): Global tool configuration for all profiling tools.
"""

tool: Optional[str] = None
enable: bool = False
all_ranks: bool = False
ranks: list[int] = field(default_factory=list)
save_path: Optional[str] = None
tool_config: Any = None
global_tool_config: Optional[Any] = None # Global tool configuration for all profiling tools

def union(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, f"Cannot union ProfilerConfig with different tools: {self.tool} vs {other.tool}"
return ProfilerConfig(
tool=self.tool,
enable=self.enable or other.enable,
all_ranks=self.all_ranks or other.all_ranks,
ranks=list(set(self.ranks or []) | set(other.ranks or [])),
save_path=self.save_path,
tool_config=self.tool_config,
global_tool_config=self.global_tool_config or other.global_tool_config,
)
Comment thread
qq1243196045 marked this conversation as resolved.

def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, (
f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
return ProfilerConfig(
tool=self.tool,
enable=self.enable and other.enable,
all_ranks=self.all_ranks and other.all_ranks,
ranks=list(set(self.ranks or []) & set(other.ranks or [])),
save_path=self.save_path,
tool_config=self.tool_config,
global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
)
Comment on lines +108 to +119

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the intersect method, if self.save_path or self.tool_config is None, it should fallback to other.save_path or other.tool_config instead of strictly using self's values which might be None.

Suggested change
def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, (
f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
return ProfilerConfig(
tool=self.tool,
enable=self.enable and other.enable,
all_ranks=self.all_ranks and other.all_ranks,
ranks=list(set(self.ranks or []) & set(other.ranks or [])),
save_path=self.save_path,
tool_config=self.tool_config,
global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
)
def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, (
f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
return ProfilerConfig(
tool=self.tool,
enable=self.enable and other.enable,
all_ranks=self.all_ranks and other.all_ranks,
ranks=list(set(self.ranks or []) & set(other.ranks or [])),
save_path=self.save_path or other.save_path,
tool_config=self.tool_config or other.tool_config,
global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
)


def __post_init__(self) -> None:
"""config validation logics go here"""
assert isinstance(self.ranks,
set | list | tuple), (f"Profiler ranks must be of type list, got {type(self.ranks)}")
Comment thread
qq1243196045 marked this conversation as resolved.
Outdated


@dataclass
class TorchProfilerToolConfig(BaseConfig):
"""Torch profiler tool config."""

# options: cuda, cpu, memory, shapes, stack
contents: list[str] = field(default_factory=list)
discrete: bool = False
name: str = 'torch'

def __post_init__(self) -> None:
"""config validation logics go here"""
__support_contents = ['cuda', 'cpu', 'memory', 'shapes', 'stack']
for content in self.contents:
assert content in __support_contents, (
f"Profiler contents only supports {__support_contents}, but gets {content}")
assert isinstance(self.contents, list), f"Profiler contents must be of type list, got {type(self.contents)}"
Comment thread
qq1243196045 marked this conversation as resolved.
Outdated
Loading
Loading