1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
| import os import random import shutil from tqdm import tqdm
def clear_dir(path):
    """Reset *path* to an empty directory.

    Anything that currently exists at *path* is removed with
    ``shutil.rmtree`` and the directory is then created afresh
    (including missing parent directories).

    Args:
        path: Directory to wipe and recreate.
    """
    already_present = os.path.exists(path)
    if already_present:
        # Drop the whole tree so no stale files survive from a previous run.
        shutil.rmtree(path)
    os.makedirs(path)
def ensure_dirs(dir_list):
    """Create every directory in *dir_list* that does not exist yet.

    Directories that are already present are left untouched, and missing
    parent directories are created as needed.

    Args:
        dir_list: Iterable of directory paths.
    """
    for target in dir_list:
        os.makedirs(target, exist_ok=True)
def _stems_with_ext(directory, ext):
    """Return the set of entry names in *directory* that end with *ext*, minus that suffix."""
    # Strip exactly ``ext`` instead of using os.path.splitext: splitext only
    # removes the last dotted component, so for extensions such as '.tar.gz'
    # the stem would not satisfy ``stem + ext == original name``.
    # Entries consisting of nothing but the extension are skipped because
    # they would yield an empty stem.
    return {
        entry[:len(entry) - len(ext)]
        for entry in os.listdir(directory)
        if entry.endswith(ext) and len(entry) > len(ext)
    }


def get_valid_filenames(image_dir, label_dir, img_ext='.jpg', label_ext='.txt'):
    """List the base names that have both an image and a label file.

    Matching on the extension is case-sensitive. Each returned stem
    satisfies ``stem + img_ext`` being an entry of *image_dir* and
    ``stem + label_ext`` being an entry of *label_dir*.

    Args:
        image_dir: Directory containing the image files.
        label_dir: Directory containing the label files.
        img_ext: Suffix identifying image files.
        label_ext: Suffix identifying label files.

    Returns:
        A sorted list of stems present in both directories.

    Raises:
        OSError: If either directory cannot be listed.
    """
    image_names = _stems_with_ext(image_dir, img_ext)
    label_names = _stems_with_ext(label_dir, label_ext)
    return sorted(image_names & label_names)
def split_dataset(filenames, train_ratio, valid_ratio):
    """Cut *filenames* into consecutive train / valid / test slices.

    The train and valid sizes are the ratios applied to the total count,
    truncated to integers; the test slice is whatever remains. The input
    order is kept, so shuffle beforehand if a random split is wanted.

    Args:
        filenames: Sequence of items to split.
        train_ratio: Fraction of items assigned to the train slice.
        valid_ratio: Fraction of items assigned to the valid slice.

    Returns:
        A ``(train, valid, test)`` tuple of slices of *filenames*.
    """
    n_items = len(filenames)
    train_end = int(n_items * train_ratio)
    valid_end = train_end + int(n_items * valid_ratio)

    train_part = filenames[:train_end]
    valid_part = filenames[train_end:valid_end]
    test_part = filenames[valid_end:]
    return train_part, valid_part, test_part
def copy_file_with_progress(src, dst):
    """Copy the contents of the file *src* to the file *dst*.

    Only the data is copied; *dst* is created or overwritten. Despite the
    name, this function reports no progress itself.

    Args:
        src: Path of the file to read.
        dst: Path of the file to write.

    Raises:
        OSError: If *src* cannot be read or *dst* cannot be written. This
            includes ``shutil.SameFileError`` when both paths refer to the
            same file.
    """
    # shutil.copyfile replaces the manual chunked read/write loop and
    # refuses to copy a file onto itself; opening the destination with
    # 'wb' in that case would truncate the source before it was read.
    shutil.copyfile(src, dst)
def batch_copy(filenames, image_dir, label_dir, img_ext, label_ext, out_img_dir, out_label_dir, pbar):
    """Copy the image/label pair of every stem into the output directories.

    For each stem the image is copied first, then the label, and finally
    the progress bar is advanced by one.

    Args:
        filenames: Iterable of base names (without extension).
        image_dir: Directory the images are read from.
        label_dir: Directory the labels are read from.
        img_ext: Extension appended to a stem to form the image file name.
        label_ext: Extension appended to a stem to form the label file name.
        out_img_dir: Directory the images are written to.
        out_label_dir: Directory the labels are written to.
        pbar: Progress object; its ``update(1)`` is called once per stem.
    """
    for stem in filenames:
        image_name = stem + img_ext
        label_name = stem + label_ext
        # (source, destination) pairs, image before label.
        transfers = (
            (os.path.join(image_dir, image_name), os.path.join(out_img_dir, image_name)),
            (os.path.join(label_dir, label_name), os.path.join(out_label_dir, label_name)),
        )
        for source, target in transfers:
            copy_file_with_progress(source, target)
        pbar.update(1)
def main():
    """Split a paired image/label dataset into train, valid and test folders.

    Reads ``<root_dir>/Images`` and ``<root_dir>/labels``, keeps the stems
    present in both, shuffles them, splits them by the configured ratios and
    copies each split into ``<output_dir>/<split>/images`` and
    ``<output_dir>/<split>/labels``. The output directory is wiped first.

    Raises:
        FileNotFoundError: If the image or label directory does not exist.
        ValueError: If the output directory is, or contains, one of the
            source directories (clearing it would delete the input data).
    """
    # NOTE: call random.seed(<int>) here to make the shuffle, and therefore
    # the split, reproducible between runs.

    # Configuration - only the paths below need to be edited.
    output_dir = '/run/media/boqi/存储/AI/YOLO/dataset/klbq/'
    root_dir = '/run/media/boqi/存储/AI/YOLO/out/klbq/'

    image_dir = os.path.join(root_dir, 'Images')
    label_dir = os.path.join(root_dir, 'labels')
    img_ext, label_ext = '.jpg', '.txt'
    # The test split receives whatever remains after train and valid.
    train_ratio, valid_ratio = 0.8, 0.1

    # Make sure the source directories exist.
    if not os.path.isdir(image_dir) or not os.path.isdir(label_dir):
        raise FileNotFoundError("图像或标签目录不存在,请检查路径!")

    # Refuse to continue when the output directory equals or contains a
    # source directory: clearing it below would delete the input data.
    resolved_out = os.path.realpath(output_dir)
    for source_dir in (image_dir, label_dir):
        resolved_src = os.path.realpath(source_dir)
        try:
            shared = os.path.commonpath([resolved_out, resolved_src])
        except ValueError:
            # No common path at all (e.g. different drives): no overlap.
            continue
        if shared == resolved_out:
            raise ValueError("输出目录不能包含图像或标签目录,否则清空时会删除源数据!")

    # Start from an empty output directory.
    clear_dir(output_dir)

    filenames = get_valid_filenames(image_dir, label_dir, img_ext, label_ext)
    random.shuffle(filenames)
    train_files, valid_files, test_files = split_dataset(filenames, train_ratio, valid_ratio)
    splits = (
        ('train', train_files),
        ('valid', valid_files),
        ('test', test_files),
    )

    # Report how many samples ended up in each split.
    for split_name, files in splits:
        print(f"{split_name}: {len(files)} 张")

    # Output directories: (images, labels) per split.
    out_dirs = {
        split_name: (
            os.path.join(output_dir, split_name, 'images'),
            os.path.join(output_dir, split_name, 'labels'),
        )
        for split_name, _ in splits
    }
    ensure_dirs([path for pair in out_dirs.values() for path in pair])

    total_files = sum(len(files) for _, files in splits)
    # One shared progress bar across all three splits.
    with tqdm(total=total_files, desc='总进度', unit='文件') as pbar:
        for split_name, files in splits:
            out_img_dir, out_label_dir = out_dirs[split_name]
            batch_copy(files, image_dir, label_dir, img_ext, label_ext,
                       out_img_dir, out_label_dir, pbar)
# Run the dataset split only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__': main()
|