import json
from PIL import Image
import torch
from torchvision import transforms
def load_labels(path, num_classes=1000):
    """Load a JSON class-index -> label mapping and return it as a list.

    The file is expected to hold a JSON object keyed by the stringified
    class index ("0", "1", ...). Element ``i`` of the returned list is the
    label for class ``i``.

    Raises:
        KeyError: if an index in ``range(num_classes)`` is missing.
    """
    # Use a context manager so the handle is closed; the original leaked it.
    with open(path, encoding='utf-8') as f:
        raw = json.load(f)
    return [raw[str(i)] for i in range(num_classes)]


def build_transform():
    """Return the preprocessing pipeline for ImageNet-style evaluation.

    Resize the shorter side to 256 and center-crop to 224 (the usual
    224 / 0.875 eval crop). Then convert to a tensor and normalize with the
    ImageNet channel means/stds.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def top_k_predictions(logits, labels, k=5):
    """Return the ``k`` most probable ``(label, probability)`` pairs.

    Args:
        logits: tensor indexed as ``logits[0, class]``; only the first row
            (batch item) is used.
        labels: sequence mapping class index -> label.
        k: number of predictions to return; clamped to the number of classes.

    Returns:
        List of ``(label, prob)`` pairs sorted by descending probability.
    """
    # Compute softmax once over the whole row. The original recomputed it
    # inside the loop for every class.
    probs = torch.softmax(logits, dim=1)[0]
    top = torch.topk(probs, k=min(k, probs.numel()))
    return [
        (labels[idx], prob)
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    ]


def main(image_path='panda.jpg', labels_path='labels_map.txt', k=5):
    """Classify ``image_path`` with EfficientNet-B0 and print the top-k labels."""
    model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
    model.eval()

    labels_map = load_labels(labels_path)

    # Force 3-channel RGB so grayscale/RGBA/palette images don't break
    # Normalize or the model's first conv. convert() loads the pixels,
    # so closing the file afterwards is safe.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')
    batch = build_transform()(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    for label, prob in top_k_predictions(logits, labels_map, k=k):
        print('{:<75} ({:.2f}%)'.format(label, prob * 100))


if __name__ == '__main__':
    main()