# 脚本虽然简单，但能给人一种豁然开朗的感觉。在动作识别中很有用。
# 效果展示脚本
import torch
import numpy as np
from network import C3D_model
import cv2
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest;
# this pays off here because every model input in this script has the same
# shape (a fixed-length clip of fixed-size cropped frames).
torch.backends.cudnn.benchmark = True
def center_crop(frame):
    """Cut a fixed 112x112 window out of a resized frame.

    Keeps rows 8..119 and columns 30..141 of every channel, which is the
    central 112x112 region of a frame already resized to 171x128
    (width x height).

    Args:
        frame: H x W x C image array.

    Returns:
        A new ``uint8`` array of shape (112, 112, C) for inputs with at
        least 120 rows and 142 columns.
    """
    row_window = slice(8, 120)
    col_window = slice(30, 142)
    patch = frame[row_window, col_window, :]
    # astype() always returns a fresh array, so the caller never gets a view
    # into the original frame.
    return np.asarray(patch).astype(np.uint8)
# Number of consecutive preprocessed frames stacked into one model input.
_CLIP_LEN = 16
# Per-channel values subtracted from every cropped frame before inference
# (presumably training-set channel means in OpenCV's BGR order; TODO confirm).
_FRAME_MEAN = np.array([[[90.0, 98.0, 102.0]]])


def _load_class_names(labels_path):
    """Return one display name per line of *labels_path*; line order is the class index.

    Only the last space-separated token of each line is kept, stripped of
    whitespace (e.g. a line ``"1 ApplyEyeMakeup"``, an assumed format that
    should be confirmed against the file, becomes ``"ApplyEyeMakeup"``).
    """
    with open(labels_path, 'r', encoding='utf-8') as f:
        return [line.split(' ')[-1].strip() for line in f]


def _load_model(checkpoint_path, device, num_classes=101):
    """Build a C3D model, load the checkpoint's ``'state_dict'``, and return it in eval mode on *device*."""
    model = C3D_model.C3D(num_classes=num_classes)
    # Load tensors onto the CPU first so a GPU-saved checkpoint also opens on
    # CPU-only machines; the model is moved to *device* afterwards.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model


def _preprocess(frame):
    """Resize a frame to 171x128, center-crop it to 112x112 and subtract the channel means."""
    return center_crop(cv2.resize(frame, (171, 128))) - _FRAME_MEAN


def _predict(model, clip, device):
    """Return ``(label_index, probability)`` of the top-1 class for a window of frames."""
    # Stack (H, W, C) frames into (T, H, W, C), then reorder to the
    # (batch, C, T, H, W) layout fed to the model.
    inputs = np.stack(clip).astype(np.float32)
    inputs = np.expand_dims(inputs, axis=0)
    inputs = np.transpose(inputs, (0, 4, 1, 2, 3))
    inputs = torch.from_numpy(inputs).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(inputs), dim=1)
    prob, label = probs.max(dim=1)
    return label.item(), prob.item()


def main(video_path='./MyC3dChange/merged_video.avi',
         checkpoint_path='./run/run_0/models/C3D-ucf101_epoch-99.pth.tar',
         labels_path='./dataloaders/ucf_labels.txt'):
    """Run the trained C3D model over a video and display the predicted action.

    A sliding window keeps the most recent ``_CLIP_LEN`` preprocessed frames.
    Once the window is full, every new frame triggers a prediction, and the
    top-1 class name and its probability are drawn onto that frame before it
    is shown. Press ``q`` in the window to stop early.

    Args:
        video_path: Video file to classify.
        checkpoint_path: Checkpoint containing a ``'state_dict'`` entry.
        labels_path: Text file with one class per line.
    """
    from collections import deque

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Device being used:", device)

    class_names = _load_class_names(labels_path)
    model = _load_model(checkpoint_path, device)

    font = cv2.FONT_HERSHEY_SIMPLEX
    red = (0, 0, 255)  # OpenCV colours are BGR

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Could not open video:", video_path)
            return

        # Once full, appending a frame evicts the oldest one, so each
        # prediction sees the latest _CLIP_LEN frames, including the current one.
        clip = deque(maxlen=_CLIP_LEN)
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break  # end of stream or read failure
            clip.append(_preprocess(frame))
            if len(clip) == _CLIP_LEN:
                label, prob = _predict(model, clip, device)
                cv2.putText(frame, class_names[label], (20, 20),
                            font, 0.6, red, 1)
                cv2.putText(frame, f"prob: {prob:.4f}", (20, 40),
                            font, 0.6, red, 1)
            cv2.imshow('result', frame)
            # waitKey returns -1 when no key is pressed; mask to the low byte.
            if cv2.waitKey(30) & 0xFF == ord('q'):
                break
    finally:
        # Always free the capture handle and windows, even on errors.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == '__main__':
    # Launch the demo only when run as a script, not when imported.
    main()