Source code for lightning.pytorch.loggers.litlogger

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fabric/PyTorch Lightning logger that enables remote experiment tracking, logging, and artifact management on
lightning.ai."""

import logging
import os
import warnings
from argparse import Namespace
from collections.abc import Mapping
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.logger import _add_prefix
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.utilities import _scan_checkpoints

if TYPE_CHECKING:
    from litlogger import Experiment

log = logging.getLogger(__name__)

_LITLOGGER_AVAILABLE = RequirementCache("litlogger>=0.1.0")


def _create_experiment_name() -> str:
    """Create a random experiment name using litlogger's generator."""
    from litlogger.generator import _create_name

    return _create_name()


[docs]class LitLogger(Logger): """Logger that enables remote experiment tracking, logging, and artifact management on lightning.ai.""" LOGGER_JOIN_CHAR = "-" def __init__( self, root_dir: Optional[_PATH] = None, name: Optional[str] = None, teamspace: Optional[str] = None, metadata: Optional[dict[str, str]] = None, store_step: bool = True, log_model: bool = False, save_logs: bool = True, checkpoint_name: Optional[str] = None, ) -> None: """Initialize the LightningLogger. Args: root_dir: Folder where logs and metadata are stored (default: ./lightning_logs). name: Name of your experiment (defaults to a generated name). teamspace: Teamspace name where charts and artifacts will appear. metadata: Extra metadata to associate with the experiment as tags. log_model: If True, automatically log model checkpoints as artifacts. save_logs: If True, capture and upload terminal logs. checkpoint_name: Override the base name for logged checkpoints. Example:: from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, BoringDataModule from lightning.pytorch.loggers.litlogger import LitLogger class LoggingModel(BoringModel): def training_step(self, batch, batch_idx: int): loss = self.step(batch) # logging the computed loss self.log("train_loss", loss) return {"loss": loss} trainer = Trainer( max_epochs=10, enable_model_summary=False, logger=LitLogger("./lightning_logs", name="boring_model") ) model = BoringModel() data_module = BoringDataModule() trainer.fit(model, data_module) trainer.test(model, data_module) """ self._root_dir = os.fspath(root_dir or "./lightning_logs") self._name = name or _create_experiment_name() self._version: Optional[str] = None self._teamspace = teamspace self._sub_dir = None self._prefix = "" self._fs = get_filesystem(self._root_dir) self._experiment: Optional[Experiment] = None self._step = -1 self._metadata = metadata or {} self._is_ready = False self._log_model = log_model self._save_logs = save_logs self._checkpoint_callback: Optional[ModelCheckpoint] = None self._logged_model_time: dict[str, float] = {} self._checkpoint_name = checkpoint_name # ────────────────────────────────────────────────────────────────────────────── # Properties # ────────────────────────────────────────────────────────────────────────────── @property @override def name(self) -> str: """Gets the name of the experiment.""" return self._name @property @override def version(self) -> Optional[str]: """Get the experiment version - its time of creation.""" return self._version @property @override def root_dir(self) -> str: """Gets the save directory where the litlogger experiments are saved.""" return self._root_dir @property @override def log_dir(self) -> str: """The directory for this run's tensorboard checkpoint. By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the constructor's version parameter instead of ``None`` or an int. """ log_dir = os.path.join(self.root_dir, self.name) if isinstance(self.sub_dir, str): log_dir = os.path.join(log_dir, self.sub_dir) log_dir = os.path.expandvars(log_dir) return os.path.expanduser(log_dir) @property def save_dir(self) -> str: return self.log_dir @property def sub_dir(self) -> Optional[str]: """Gets the sub directory where the TensorBoard experiments are saved.""" return self._sub_dir @property def _experiment_name(self) -> str: if self.version is None: return self.name return f"{self.name}-{self.version}" @staticmethod def _default_artifact_key(path: str) -> str: try: rel = os.path.relpath(path) except ValueError: rel = None key = rel if rel is not None and not rel.startswith("..") else os.path.basename(path) return key.replace("\\", "/") def _model_key(self) -> str: return self._experiment_name @staticmethod def _model_version(version: Optional[str], step: Optional[int]) -> Optional[str]: if version is not None: return version if step is not None and step >= 0: return str(step) return None @property @rank_zero_experiment def experiment(self) -> Optional["Experiment"]: """Returns the underlying litlogger Experiment object.""" import litlogger if self._experiment is not None: return self._experiment if not self._is_ready: self._is_ready = True assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0" if self.root_dir: self._fs.makedirs(self.root_dir, exist_ok=True) if self.version is None: # Generate version as proper RFC 3339 timestamp with Z suffix (required by protobuf) timestamp = datetime.now(timezone.utc).isoformat(timespec="milliseconds") self._version = timestamp.replace(":", "-").replace("+00:00", "Z") self._experiment = litlogger.Experiment( name=self._experiment_name, teamspace=self._teamspace, metadata={k: str(v) for k, v in self._metadata.items()}, store_step=True, store_created_at=True, log_dir=self.log_dir, save_logs=self._save_logs, ) self._experiment.print_url() return self._experiment def _require_experiment(self) -> "Experiment": experiment = self.experiment if experiment is None: raise RuntimeError("Experiment is not initialized") return experiment @property @rank_zero_only def url(self) -> str: return self._require_experiment().url # ────────────────────────────────────────────────────────────────────────────── # Override methods from Logger # ──────────────────────────────────────────────────────────────────────────────
[docs] @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" # Ensure experiment is initialized experiment = self._require_experiment() self._step = self._step + 1 if step is None else step metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) metrics = {k: v.item() if isinstance(v, Tensor) else v for k, v in metrics.items()} for key, value in metrics.items(): experiment[key].append(value, step=self._step)
[docs] @override @rank_zero_only def log_hyperparams( self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None, ) -> None: """Log hyperparams.""" if isinstance(params, Namespace): params = params.__dict__ experiment = self._require_experiment() for key, value in params.items(): experiment[key] = str(value)
[docs] @override @rank_zero_only def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: warnings.warn("LitLogger does not support `log_graph`", UserWarning, stacklevel=2)
[docs] @override @rank_zero_only def save(self) -> None: pass
[docs] @override @rank_zero_only def finalize(self, status: Optional[str] = None) -> None: if self._experiment is not None: # log checkpoints as artifacts before finalizing if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) self._experiment.finalize(status)
# ────────────────────────────────────────────────────────────────────────────── # Public methods # ──────────────────────────────────────────────────────────────────────────────
[docs] @rank_zero_only def log_metadata( self, params: Union[dict[str, Any], Namespace], ) -> None: """Log hyperparams.""" if isinstance(params, Namespace): params = params.__dict__ experiment = self._require_experiment() for key, value in params.items(): experiment[key] = str(value)
[docs] @rank_zero_only def log_model( self, model: Any, staging_dir: Optional[str] = None, verbose: bool = False, version: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, ) -> None: """Save and upload a model object to cloud storage. Args: model: The model object to save and upload (e.g., torch.nn.Module). staging_dir: Optional local directory for staging the model before upload. verbose: Whether to show progress bar during upload. version: Optional version string for the model. metadata: Optional metadata dictionary to store with the model. """ from litlogger import Model self._require_experiment()[self._model_key()] = Model( model, version=self._model_version(version, self._step), metadata=cast(Optional[dict[str, str]], metadata), staging_dir=staging_dir, )
[docs] @rank_zero_only def log_model_artifact( self, path: str, verbose: bool = False, version: Optional[str] = None, ) -> None: """Upload a model file or directory to cloud storage using litmodels. Args: path: Path to the local model file or directory to upload. verbose: Whether to show progress bar during upload. Defaults to False. version: Optional version string for the model. Defaults to the experiment version. """ from litlogger import Model self._require_experiment()[self._model_key()] = Model(path, version=self._model_version(version, self._step))
[docs] @rank_zero_only def get_file(self, path: str, verbose: bool = True) -> str: """Download a file artifact from the cloud for this experiment. Args: path: Path where the file should be saved locally. verbose: Whether to print a confirmation message after download. Defaults to True. Returns: str: The local path where the file was saved. """ file = cast(Any, self._require_experiment()[self._default_artifact_key(path)]) return file.save(path)
[docs] @rank_zero_only def log_file(self, path: str) -> None: """Log a file as an artifact to the Lightning platform. The file will be logged in the Teamspace drive, under a folder identified by the experiment name. Args: path: Path to the file to log. Example:: logger = LitLogger(...) logger.log_file('config.yaml') """ from litlogger import File self._require_experiment()[self._default_artifact_key(path)] = File(path)
# ────────────────────────────────────────────────────────────────────────────── # Callback methods # ──────────────────────────────────────────────────────────────────────────────
[docs] def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: """Called after a checkpoint is saved. Logs checkpoints as artifacts if enabled. """ if self._log_model is False: return if checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) else: self._checkpoint_callback = checkpoint_callback
# ────────────────────────────────────────────────────────────────────────────── # Private methods # ────────────────────────────────────────────────────────────────────────────── def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: """Find new checkpoints from the callback and log them as model artifacts.""" checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) for timestamp, path_ckpt, _score, _tag in checkpoints: experiment = self._require_experiment() checkpoint_key = self._checkpoint_name or experiment.name checkpoint_step = getattr(checkpoint_callback, "_last_global_step_saved", None) from litlogger import Model experiment[checkpoint_key] = Model(path_ckpt, version=self._model_version(None, checkpoint_step)) # remember logged models - timestamp needed in case filename didn't change self._logged_model_time[path_ckpt] = timestamp