四、关于get_experiment_logger
get_experiment_logger里面放的是什么呢?
def _get_run_name(config: DictConfig | ListConfig) -> str:
    """Return ``"<dataset.category> <model.name>"`` when the dataset config has a category, else ``model.name``."""
    if "category" in config.dataset.keys():
        return f"{config.dataset.category} {config.model.name}"
    return config.model.name


def _unknown_logger_error(value: object) -> Exception:
    """Build the ``UnknownLogger`` exception reported for an unsupported ``logging.logger`` value."""
    return UnknownLogger(
        f"Unknown logger type: {value}. "
        f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
        "To enable a logger, set `logging.logger` to one of the available loggers (or a list of them) in config.yaml\n"
        "To disable the logger, set `logging.logger` to `false`."
    )


def get_experiment_logger(
    config: DictConfig | ListConfig,
) -> Logger | Iterable[Logger] | bool:
    """Return the experiment logger(s) selected in the config file.

    The choice is read from ``config.logging.logger``, which may be a single
    logger name or a list of names. The deprecated ``config.project.logger``
    key is still honoured: a ``DeprecationWarning`` is emitted and its value is
    copied into the ``logging`` section (the section is created if missing).

    Args:
        config (DictConfig | ListConfig): config.yaml file for the corresponding anomalib model.

    Raises:
        UnknownLogger: if a requested logger is not one of ``"tensorboard"``,
            ``"wandb"``, ``"comet"`` or ``"csv"``, or if ``logging.logger`` is ``true``.

    Returns:
        Logger | Iterable[Logger] | bool: ``False`` when ``logging.logger`` is
        ``None``/``false``; otherwise a list with one logger per requested name
        (an empty list when ``logging.logger`` is an empty list).
    """
    logger.info("Loading the experiment logger(s)")

    # TODO remove when logger is deprecated from project
    if "logger" in config.project.keys():
        warnings.warn(
            "'logger' key will be deprecated from 'project' section of the config file."
            " Please use the logging section in config file.",
            DeprecationWarning,
        )
        # Migrate the deprecated key into the `logging` section.
        if "logging" not in config:
            config.logging = {"logger": config.project.logger, "log_graph": False}
        else:
            config.logging.logger = config.project.logger

    if config.logging.logger in (None, False):
        return False
    if config.logging.logger is True:
        # `true` is neither a logger name nor iterable; report it explicitly
        # instead of failing with a TypeError in the loop below.
        raise _unknown_logger_error(config.logging.logger)

    # A single logger name is normalised to a one-element list.
    if isinstance(config.logging.logger, str):
        config.logging.logger = [config.logging.logger]

    logger_list: list[Logger] = []
    for experiment_logger in config.logging.logger:
        logger.debug("Creating experiment logger: %s", experiment_logger)
        log_dir = os.path.join(config.project.path, "logs")
        if experiment_logger == "tensorboard":
            logger_list.append(
                AnomalibTensorBoardLogger(
                    name="Tensorboard Logs",
                    save_dir=log_dir,
                    log_graph=config.logging.log_graph,
                )
            )
        elif experiment_logger == "wandb":
            # Make sure the log directory exists before handing it to the logger.
            Path(log_dir).mkdir(parents=True, exist_ok=True)
            logger_list.append(
                AnomalibWandbLogger(
                    project=config.dataset.name,
                    name=_get_run_name(config),
                    save_dir=log_dir,
                )
            )
        elif experiment_logger == "comet":
            # Make sure the log directory exists before handing it to the logger.
            Path(log_dir).mkdir(parents=True, exist_ok=True)
            logger_list.append(
                AnomalibCometLogger(
                    project_name=config.dataset.name,
                    experiment_name=_get_run_name(config),
                    save_dir=log_dir,
                )
            )
        elif experiment_logger == "csv":
            logger_list.append(CSVLogger(save_dir=log_dir))
        else:
            # Report the offending entry, not the whole configured list.
            raise _unknown_logger_error(experiment_logger)

    logger.debug("Experiment loggers: %s", logger_list)
    return logger_list
你看最后这几行,我打印出来的logger_list,是一个[]。既然函数没有返回False也没有抛异常,说明配置里的logging.logger是个空列表,for循环一次都没进去。
白折腾了,呵呵。
五、关于get_callbacks
callbacks = get_callbacks(config) 看看里面都是啥:
def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
    """Return base callbacks for all the lightning models.

    Builds, in order:
    - the checkpoint and timer callbacks;
    - an optional checkpoint-loading callback;
    - the post-processing and metrics configuration callbacks;
    - an optional score-normalization callback.

    It then lets ``add_visualizer_callback`` extend the list. Finally it
    appends the optional NNCF / export callbacks and an optional graph-logger
    callback.

    Args:
        config (DictConfig | ListConfig): Model config.

    Raises:
        NotImplementedError: if CDF normalization is requested for a model other
            than PADIM/STFPM, or while NNCF is enabled.
        ValueError: if ``model.normalization_method`` is not recognized.

    Returns:
        list[Callback]: List of callbacks.
    """
    logger.info("Loading the callbacks")

    callbacks: list[Callback] = []

    # Checkpoint on the early-stopping metric when early stopping is configured.
    if "early_stopping" in config.model.keys():
        monitor_metric = config.model.early_stopping.metric
        monitor_mode = config.model.early_stopping.mode
    else:
        monitor_metric = None
        monitor_mode = "max"

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )
    callbacks.extend([checkpoint, TimerCallback()])

    if "resume_from_checkpoint" in config.trainer.keys() and config.trainer.resume_from_checkpoint is not None:
        callbacks.append(LoadModelCallback(config.trainer.resume_from_checkpoint))

    # Add post-processing configurations to AnomalyModule.
    threshold_config = config.metrics.threshold
    image_threshold = threshold_config.manual_image if "manual_image" in threshold_config.keys() else None
    pixel_threshold = threshold_config.manual_pixel if "manual_pixel" in threshold_config.keys() else None
    callbacks.append(
        PostProcessingConfigurationCallback(
            threshold_method=threshold_config.method,
            manual_image_threshold=image_threshold,
            manual_pixel_threshold=pixel_threshold,
        )
    )

    # Add metric configuration to the model via MetricsConfigurationCallback
    callbacks.append(
        MetricsConfigurationCallback(
            config.dataset.task,
            config.metrics.get("image", None),
            config.metrics.get("pixel", None),
        )
    )

    # The `optimization` section is optional. Resolve it once so that the CDF
    # check below does not access `config.optimization` when it is absent.
    optimization = config.optimization if "optimization" in config.keys() else None
    nncf_enabled = optimization is not None and "nncf" in optimization and bool(optimization.nncf.apply)

    normalization_method = (
        config.model.normalization_method if "normalization_method" in config.model.keys() else "none"
    )
    if normalization_method == "cdf":
        if config.model.name not in ("padim", "stfpm"):
            raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        if nncf_enabled:
            raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
        callbacks.append(CdfNormalizationCallback())
    elif normalization_method == "min_max":
        callbacks.append(MinMaxNormalizationCallback())
    elif normalization_method != "none":
        raise ValueError(f"Normalization method not recognized: {normalization_method}")

    add_visualizer_callback(callbacks, config)

    if optimization is not None:
        if nncf_enabled:
            # NNCF wraps torch's jit which conflicts with kornia's jit calls.
            # Hence, nncf is imported only when required
            nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
            nncf_config = yaml.safe_load(OmegaConf.to_yaml(optimization.nncf))
            callbacks.append(
                nncf_module.NNCFCallback(
                    config=nncf_config,
                    export_dir=os.path.join(config.project.path, "compressed"),
                )
            )

        export_mode = optimization.get("export_mode")
        if export_mode is not None:
            from .export import ExportCallback  # pylint: disable=import-outside-toplevel

            logger.info("Setting model export to %s", export_mode)
            callbacks.append(
                ExportCallback(
                    input_size=config.model.input_size,
                    dirpath=config.project.path,
                    filename="model",
                    export_mode=ExportMode(export_mode),
                )
            )
        else:
            # Reached only when no export mode is configured, so do not
            # interpolate the (None) value into the message.
            warnings.warn(
                "Export option not set (`optimization.export_mode` is missing or null). "
                "Defaulting to no model export"
            )

    # Add callback to log graph to loggers
    if config.logging.log_graph not in (None, False):
        callbacks.append(GraphLogger())

    logger.debug("Callbacks: %s", callbacks)
    return callbacks
看看我最后打印出来的callbacks,是啥样的?
六、关于Trainer
终于来到了最重要的地方了:
Trainer的第一个参数,是下面这样婶儿的:
第二个参数,前面我们说了,就是get_experiment_logger返回的那个[]
第三个参数,就是上面那一堆callbacks