# 多分类自定义采样比例 (multi-class classification: custom per-class sampling ratios)
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
# 假设你有一个自定义的数据集类
class CustomDataset(Dataset):
    """Image-folder dataset that also exposes per-class sampling weights.

    Wraps ``torchvision.datasets.ImageFolder`` and precomputes
    ``class_weights`` so that per-sample weights (see ``sample_weights``)
    can be fed to ``WeightedRandomSampler`` to draw classes in a chosen
    proportion. By default every class is drawn equally often.

    Args:
        data_dir: Root directory passed to ``ImageFolder``.
        transform: Optional transform forwarded to ``ImageFolder``.
        class_ratios: Optional desired relative sampling frequency per class,
            indexed like ``ImageFolder.classes`` (i.e. by integer target).
            ``None`` means uniform ratios, giving weight ``1 / count``.
    """

    def __init__(self, data_dir, transform=None, class_ratios=None):
        self.dataset = ImageFolder(data_dir, transform=transform)
        self.class_weights = self.calculate_class_weights(class_ratios)

    def calculate_class_weights(self, class_ratios=None):
        """Return a float tensor with one sampling weight per class.

        The weight of class ``c`` is ``class_ratios[c] / count[c]``. If each
        sample is weighted by its class's entry, the total weight of class
        ``c`` equals ``class_ratios[c]``. A weighted sampler therefore draws
        class ``c`` with probability proportional to ``class_ratios[c]``.
        Classes with no samples get weight 0 rather than ``inf``.

        Args:
            class_ratios: Optional sequence of non-negative numbers, one per
                class. ``None`` means all ones (uniform class sampling).

        Returns:
            A 1-D float tensor of length ``len(self.dataset.classes)``.

        Raises:
            ValueError: If ``class_ratios`` has the wrong length or contains
                a negative value.
        """
        num_classes = len(self.dataset.classes)
        targets = torch.as_tensor(self.dataset.targets, dtype=torch.long)
        # One O(n) pass instead of calling list.count once per class (O(n*k)).
        # Slicing keeps exactly one entry per declared class.
        class_counts = torch.bincount(targets, minlength=num_classes)[:num_classes].float()

        if class_ratios is None:
            ratios = torch.ones(num_classes)
        else:
            ratios = torch.as_tensor(class_ratios, dtype=torch.float).flatten()
            if ratios.numel() != num_classes:
                raise ValueError(
                    f"class_ratios must have {num_classes} entries, got {ratios.numel()}"
                )
            if (ratios < 0).any():
                raise ValueError("class_ratios must be non-negative")

        # clamp(min=1) avoids 1/0 -> inf; empty classes are then zeroed so the
        # tensor stays finite (it can never be sampled from anyway).
        class_weights = ratios / class_counts.clamp(min=1)
        class_weights[class_counts == 0] = 0.0
        return class_weights

    def sample_weights(self):
        """Return one weight per sample (its class weight), for WeightedRandomSampler."""
        targets = torch.as_tensor(self.dataset.targets, dtype=torch.long)
        return self.class_weights[targets]

    def __len__(self):
        """Return the number of samples in the wrapped ImageFolder."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped ImageFolder's item at ``idx`` unchanged."""
        return self.dataset[idx]
# Dataset root directory (placeholder path; replace with your own).
data_dir = "path/to/your/dataset"
# 定义图像转换
transform = transforms.Compose([
transforms.Resize((224, 224)),
tra