源码来源
https://www.cnblogs.com/zkweb/p/14403833.html
原文作者的文章很nice,受益匪浅,我画蛇添足一下,对数据读取部分的代码做一个更详细地解析。
def prepare():
"""准备训练"""
# 数据集转换到 tensor 以后会保存在 data 文件夹下
if not os.path.isdir("data"):
os.makedirs("data")
# 加载图片和图片对应的区域与分类列表
# { (路径, 是否左右翻转): [ 区域与分类, 区域与分类, .. ] }
# 同一张图片左右翻转可以生成一个新的数据,让数据量翻倍
box_map = defaultdict(lambda: []) ##defaultdict可以在没key的时候,查询key的时候返回的是[]
for filename in os.listdir(DATASET_1_IMAGE_DIR):
# 从第一个数据集加载
xml_path = os.path.join(DATASET_1_ANNOTATION_DIR, filename.split(".")[0] + ".xml")
if not os.path.isfile(xml_path):
continue
tree = ET.ElementTree(file=xml_path)
objects = tree.findall("object")
path = os.path.join(DATASET_1_IMAGE_DIR, filename)
for obj in objects:
class_name = obj.find("name").text
x1 = int(obj.find("bndbox/xmin").text)
x2 = int(obj.find("bndbox/xmax").text)
y1 = int(obj.find("bndbox/ymin").text)
y2 = int(obj.find("bndbox/ymax").text)
if class_name == "mask_weared_incorrect":
# 佩戴口罩不正确的样本数量太少 (只有 123),模型无法学习,这里全合并到戴口罩的样本
class_name = "with_mask"
box_map[(path, False)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING[class_name]))
box_map[(path, True)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING[class_name]))
df = pandas.read_csv(DATASET_2_BOX_CSV_PATH)
for row in df.values:
# 从第二个数据集加载,这个数据集只包含没有带口罩的图片
filename, width, height, x1, y1, x2, y2 = row[:7]
path = os.path.join(DATASET_2_IMAGE_DIR, filename)
## False和True注意上方的注释,True的话就将这张图左右翻转一下
box_map[(path, False)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING["without_mask"]))
box_map[(path, True)].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING["without_mask"]))
# 打乱数据集 (因为第二个数据集只有不戴口罩的图片)
box_list = list(box_map.items())
random.shuffle(box_list)
print(f"found {len(box_list)} images")
# 保存图片和图片对应的分类与区域列表
batch_size = 20
batch = 0
image_tensors = [] # 图片列表
result_tensors = [] # 图片对应的输出结果列表,包含 [ 是否对象中心, 区域偏移, 各个分类的可能性 ]
result_isobject_masks = [] # 各个图片的包含对象的区域在 Anchors 中的索引
result_nonobject_masks = [] # 各个图片不包含对象的区域在 Anchors 中的索引 (重叠率低于阈值的区域)
for (image_path, flip), original_boxes_labels in box_list:
with Image.open(image_path) as img_original: # 加载原始图片
sw, sh = img_original.size # 原始图片大小
if flip:
## 自定义resize_image方法,主要做的内容是用空白填充出一个符合(256,192)比例的图
## 然后再用reshape来进行缩放
## 这样子缩放的时候可以保证原图片的比例没有变化
img = resize_image(img_original.transpose(Image.FLIP_LEFT_RIGHT)) # 翻转然后缩放图片
else:
img = resize_image(img_original) # 缩放图片
image_tensors.append(image_to_tensor(img)) # 添加图片到列表
# 生成输出结果的 tensor
## size是(锚点数,1+4+分类数量)
result_tensor = torch.zeros((len(MyModel.Anchors), MyModel.AnchorOutputs), dtype=torch.float)
result_tensor[:,5] = 1 # 默认分类为 other
result_tensors.append(result_tensor)
# 包含对象的区域在 Anchors 中的索引
result_isobject_mask = []
result_isobject_masks.append(result_isobject_mask)
# 不包含对象的区域在 Anchors 中的索引
result_nonobject_mask = []
result_nonobject_masks.append(result_nonobject_mask)
# 根据真实区域定位所属的锚点,然后设置输出结果
negative_mapping = [1] * len(MyModel.Anchors)
for box_label in original_boxes_labels:
x, y, w, h, label = box_label
if flip: # 翻转坐标
x = sw - x - w
## 前面提到原图需要缩放到256,198的尺寸
## 而此时的标记框xywh是原图的,需要将其映射回缩放后的
x, y, w, h = map_box_to_resized_image((x, y, w, h), sw, sh) # 缩放实际区域
if w < 20 or h < 20:
continue # 缩放后区域过小
# 检查计算是否有问题
# child_img = img.copy().crop((x, y, x+w, y+h))
# child_img.save(f"{os.path.basename(image_path)}_{x}_{y}_{w}_{h}_{label}.png")
# 定位所属的锚点
# 要求:
# - 中心点落在锚点对应的区域中
# - 重叠率超过一定值
x_center = x + w // 2
y_center = y + h // 2
matched_anchors = []
for index, anchor in enumerate(MyModel.Anchors):
ax, ay, aw, ah = anchor
is_center = (x_center >= ax and x_center < ax + aw and
y_center >= ay and y_center < ay + ah)
iou = calc_iou(anchor, (x, y, w, h))
if is_center and iou > IOU_POSITIVE_THRESHOLD:
matched_anchors.append((index, anchor)) # 区域包含对象中心并且重叠率超过一定值
negative_mapping[index] = 0
elif iou > IOU_NEGATIVE_THRESHOLD:
negative_mapping[index] = 0 # 区域与某个对象重叠率超过一定值,不应该当作负样本
for matched_index, matched_box in matched_anchors:
# 计算区域偏移
offset = calc_box_offset(matched_box, (x, y, w, h))
# 修改输出结果的 tensor
result_tensor[matched_index] = torch.tensor((
1, # 是否对象中心
*offset, # 区域偏移
*[int(c == label) for c in range(len(CLASSES))] # 对应分类
), dtype=torch.float)
# 添加索引值
# 注意如果两个对象同时定位到相同的锚点,那么只有一个对象可以被识别,这里后面的对象会覆盖前面的对象
if matched_index not in result_isobject_mask:
result_isobject_mask.append(matched_index)
# 没有找到可识别的对象时跳过图片
if not result_isobject_mask:
image_tensors.pop()
result_tensors.pop()
result_isobject_masks.pop()
result_nonobject_masks.pop()
continue
# 添加不包含对象的区域在 Anchors 中的索引
for index, value in enumerate(negative_mapping):
if value:
result_nonobject_mask.append(index)
# 排序索引列表
result_isobject_mask.sort()
# 保存批次
if len(image_tensors) >= batch_size:
prepare_save_batch(batch, image_tensors, result_tensors,
result_isobject_masks, result_nonobject_masks)
image_tensors.clear()
result_tensors.clear()
result_isobject_masks.clear()
result_nonobject_masks.clear()
batch += 1
# 保存剩余的批次
if len(image_tensors) > 10:
prepare_save_batch(batch, image_tensors, result_tensors,
result_isobject_masks, result_nonobject_masks)
def resize_image(img):
"""缩放图片,比例不一致时填充"""
sw, sh = img.size
sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
img_new = Image.new("RGB", (sw_new, sh_new))
## 从pad_w,pad_h这个坐标开始粘贴img
img_new.paste(img, (pad_w, pad_h))
img_new = img_new.resize(IMAGE_SIZE)
return img_new
def calc_resize_parameters(sw, sh):
"""计算缩放图片的参数"""
sw_new, sh_new = sw, sh
dw, dh = IMAGE_SIZE
pad_w, pad_h = 0, 0
if sw / sh < dw / dh:
sw_new = int(dw / dh * sh)
pad_w = (sw_new - sw) // 2 # 填充左右
else:
sh_new = int(dh / dw * sw)
pad_h = (sh_new - sh) // 2 # 填充上下
return sw_new, sh_new, pad_w, pad_h
def map_box_to_resized_image(box, sw, sh):
"""把原始区域转换到缩放后的图片对应的区域"""
x, y, w, h = box
sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
scale = IMAGE_SIZE[0] / sw_new
x = int((x + pad_w) * scale)
y = int((y + pad_h) * scale)
w = int(w * scale)
h = int(h * scale)
if x + w > IMAGE_SIZE[0] or y + h > IMAGE_SIZE[1] or w == 0 or h == 0:
return 0, 0, 0, 0
return x, y, w, h
def calc_iou(rect1, rect2):
"""计算两个区域重叠部分 / 合并部分的比率 (intersection over union)"""
x1, y1, w1, h1 = rect1
x2, y2, w2, h2 = rect2
xi = max(x1, x2)
yi = max(y1, y2)
wi = min(x1+w1, x2+w2) - xi
hi = min(y1+h1, y2+h2) - yi
if wi > 0 and hi > 0: # 有重叠部分
area_overlap = wi*hi
area_all = w1*h1 + w2*h2 - area_overlap
iou = area_overlap / area_all
else: # 没有重叠部分
iou = 0
return iou
代码中双##的部分就是我的注释,原文的作者代码已经足够详细了,但是作为初学者还是有些地方看不懂,所以就再补充一点。