要将新的异常检测模型集成到Anomalib中,可以按照以下步骤进行操作:
1 Create a new sub-package
在anomalib/models中创建的一个新目录,用于存储与模型相关的文件。
./anomalib/models/<new-model>
├── __init__.py
├── config.yaml
├── torch_model.py
├── lightning_model.py
├── loss.py # OPTIONAL
├── anomaly_map.py # OPTIONAL
└── README.md
3 Create a config.yaml file.
config.yaml文件存储了所有的配置信息,包括数据和优化选项。下面是一个示例的yaml文件:
dataset:
name: mvtec #options: [mvtec, btech, folder]
format: mvtec
...
model:
name: patchcore
backbone: wide_resnet50_2
...
metrics:
image:
- F1Score
...
visualization:
show_images: False # show images on the screen
...
# PL Trainer Args. Don't add extra parameter here.
trainer:
accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
...
4 Create a torch_model.py file.
torch_model.py文件包含了继承自torch.nn.Module的torch模型实现,定义了模型的架构并执行基本的前向传播。将模型存储在一个独立的torch_model.py文件中的优势是,模型与anomalib的其他实现解耦,也可以在库之外使用。基本实现如下所示:
class NewModelModel(nn.Module):
"""New Model Module."""
def __init__(self):
pass
def forward(self, x):
pass
5 Create a lightning_model.py file.
lightning_model.py模块包含了继承自AnomalModule的lightning模型实现,AnomalModule已经具有与anomalib相关的属性和方法。用户不需要担心样板代码,只需要实现算法的训练和验证逻辑即可。
class NewModel(AnomalyModule):
"""PL Lightning Module for the New Model."""
def __init__(self):
super().__init__()
pass
def training_step(self, batch):
pass
...
def validation_step(self, batch):
pass
6 [OPTIONAL] Create a loss.py file.
如果算法需要自定义的复杂损失函数,则需要实现loss.py文件。loss.py文件包含torch.nn.Module类实现的子类。然后,lightning模块将使用这个损失函数。
class NewModelLoss(nn.Module):
"""NewModel Loss."""
def forward(self) -> Tensor:
"""Calculate the NewModel loss."""
pass
7 [OPTIONAL] Create an anomaly_map.py file.
如果算法支持分割,那么可以实现这个模块,anomaly_map.py模块根据算法的能力以便逐像素地预测异常的位置。
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""
def __init__(self, input_size: Union[ListConfig, Tuple]):
pass
def forward(self, x: Tensor) -> Tensor:
"""Generate Anomaly Heatmap."""
...
return anomaly_map
8 Create a README.md file.
编写readme
# Name of the Model
## Description
Brief description of the paper.
## Architecture
A diagram showing the high-level overview.
## Usage
python tools/train.py --model <newmodel>
## Benchmark
Benchmark results on MVTec categories.