248 lines
7.5 KiB
Python
248 lines
7.5 KiB
Python
#!/usr/bin/env python3
"""
Split the Kaggle wood defects COCO dataset into train/valid/test splits.
Creates both COCO format and YOLO format annotations.

Usage:
    python split_coco_dataset.py --input bbox_coco_dataset.json --images IMAGE/ --output dataset_split
"""
|
|
|
|
import argparse
|
|
import json
|
|
import random
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Dict, List
|
|
|
|
|
|
def coco_to_yolo_bbox(bbox: List[float], img_width: int, img_height: int) -> List[float]:
    """
    Turn a COCO box into a YOLO box.

    A COCO box is ``[x_min, y_min, width, height]`` in pixels. The YOLO box
    returned here is ``[x_center, y_center, width, height]``, where each
    horizontal quantity is divided by ``img_width`` and each vertical
    quantity by ``img_height``.

    Args:
        bbox: COCO format [x_min, y_min, width, height]
        img_width: Image width in pixels
        img_height: Image height in pixels

    Returns:
        YOLO format [x_center, y_center, width, height], each value divided
        by the matching image dimension. No clamping is applied, so a box
        that extends past the image yields values outside [0, 1].
    """
    left, top, box_w, box_h = bbox

    return [
        (left + box_w / 2) / img_width,   # horizontal centre
        (top + box_h / 2) / img_height,   # vertical centre
        box_w / img_width,                # relative width
        box_h / img_height,               # relative height
    ]
|
|
|
|
|
|
def split_coco_dataset(
    input_json: Path,
    images_dir: Path,
    output_dir: Path,
    train_split: float = 0.8,
    valid_split: float = 0.1,
    seed: int = 42
):
    """
    Split COCO dataset into train/valid/test splits.

    For every split ``<name>`` in train/valid/test this writes, under
    ``output_dir / <name>``:

    * ``_annotations.coco.json`` with the split's images, their annotations
      and the full category list;
    * a copy of every image of the split that exists in ``images_dir``;
    * ``labels/<image stem>.txt`` with one YOLO line per annotation
      (``class x_center y_center width height``) for every copied image.

    A ``data.yaml`` describing the three splits and the class names is
    written to ``output_dir``. Whatever is left after the train and valid
    fractions becomes the test split.

    NOTE(review): labels live in ``<split>/labels/`` while the images sit
    directly in ``<split>/``. Some YOLO loaders locate labels by swapping an
    ``images`` path component for ``labels`` - confirm that the training
    tool in use actually picks these label files up.

    Args:
        input_json: Path to input COCO JSON file
        images_dir: Directory containing all images
        output_dir: Output directory for splits
        train_split: Fraction for training (default 0.8)
        valid_split: Fraction for validation (default 0.1)
        seed: Random seed for reproducibility

    Raises:
        ValueError: If a fraction is outside [0, 1] or the two fractions
            add up to more than 1.
        KeyError: If the JSON lacks ``images``/``annotations``/``categories``
            or an annotation references an unknown category id.
    """
    # Validate before touching the filesystem: fractions that do not describe
    # a partition used to produce a negative test count and truncated splits.
    if not (0.0 <= train_split <= 1.0 and 0.0 <= valid_split <= 1.0):
        raise ValueError(
            f"train_split and valid_split must be within [0, 1], "
            f"got {train_split} and {valid_split}"
        )
    if train_split + valid_split > 1.0 + 1e-9:
        raise ValueError(
            f"train_split + valid_split must not exceed 1, "
            f"got {train_split + valid_split}"
        )

    # Load COCO data (JSON is UTF-8; do not depend on the locale encoding)
    with input_json.open('r', encoding='utf-8') as f:
        data = json.load(f)

    images = list(data['images'])
    annotations = data['annotations']
    categories = data['categories']

    # A private generator yields the same permutation as
    # random.seed(seed) + random.shuffle(...) without reseeding the
    # process-wide RNG that other code may rely on.
    random.Random(seed).shuffle(images)

    # Calculate split sizes
    n_images = len(images)
    n_train = int(n_images * train_split)
    n_valid = int(n_images * valid_split)
    n_test = n_images - n_train - n_valid

    print(f"Total images: {n_images}")
    print(f"Train: {n_train}, Valid: {n_valid}, Test: {n_test}")

    # Create splits
    splits = {
        'train': images[:n_train],
        'valid': images[n_train:n_train + n_valid],
        'test': images[n_train + n_valid:]
    }

    # Create output directories
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create category ID to index mapping (YOLO uses 0-indexed categories)
    category_id_to_idx = {cat['id']: idx for idx, cat in enumerate(categories)}

    # Group annotations by image once; every split reuses this lookup.
    annotations_by_image: Dict[int, List] = {}
    for ann in annotations:
        annotations_by_image.setdefault(ann['image_id'], []).append(ann)

    for split_name, split_images in splits.items():
        split_dir = output_dir / split_name
        split_dir.mkdir(exist_ok=True)

        # Create labels directory for YOLO format
        labels_dir = split_dir / 'labels'
        labels_dir.mkdir(exist_ok=True)

        # Keep the annotations of this split in their original file order
        split_image_ids = {img['id'] for img in split_images}
        split_annotations = [
            ann for ann in annotations
            if ann['image_id'] in split_image_ids
        ]

        # Save COCO JSON
        split_data = {
            'images': split_images,
            'annotations': split_annotations,
            'categories': categories
        }
        json_path = split_dir / '_annotations.coco.json'
        with json_path.open('w', encoding='utf-8') as f:
            json.dump(split_data, f, indent=2)

        # Copy images and create YOLO labels
        copied = 0
        for img in split_images:
            src_path = images_dir / img['file_name']
            if not src_path.exists():
                # Missing images are reported and get no label file.
                print(f"Warning: {src_path} not found")
                continue

            dst_path = split_dir / img['file_name']
            # file_name may contain sub-directories; make sure they exist.
            dst_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_path, dst_path)
            copied += 1

            # Build all label lines first so the file is written in one go.
            label_lines = []
            for ann in annotations_by_image.get(img['id'], ()):
                x_center, y_center, box_w, box_h = coco_to_yolo_bbox(
                    ann['bbox'],
                    img['width'],
                    img['height']
                )
                cat_idx = category_id_to_idx[ann['category_id']]
                # YOLO format: class x_center y_center width height
                label_lines.append(
                    f"{cat_idx} {x_center:.6f} {y_center:.6f} "
                    f"{box_w:.6f} {box_h:.6f}\n"
                )

            # An image without annotations still gets an (empty) label file.
            label_file = labels_dir / f"{Path(img['file_name']).stem}.txt"
            with label_file.open('w', encoding='utf-8') as f:
                f.writelines(label_lines)

        print(f"{split_name}: {len(split_images)} images, {len(split_annotations)} annotations, {copied} copied")

    # Create data.yaml for YOLO training
    data_yaml_path = output_dir / 'data.yaml'
    data_yaml_content = f"""# YOLO dataset configuration
path: {output_dir.absolute()} # dataset root dir
train: train # train images (relative to 'path')
val: valid # val images (relative to 'path')
test: test # test images (relative to 'path')

# Classes
names:
"""
    data_yaml_content += "".join(
        f" {idx}: {cat['name']}\n" for idx, cat in enumerate(categories)
    )

    with data_yaml_path.open('w', encoding='utf-8') as f:
        f.write(data_yaml_content)

    print(f"\nDataset split complete! Saved to: {output_dir}")
    print(f"Created YOLO format labels in {output_dir}/{{train,valid,test}}/labels/")
    print(f"Created data.yaml at {data_yaml_path}")
|
|
|
|
|
|
def main():
    """Parse the command-line options and run the dataset split with them."""
    cli = argparse.ArgumentParser(description="Split COCO dataset into train/valid/test")

    # (flag, converter, default, help text) for every supported option.
    option_table = (
        ("--input", Path, "bbox_coco_dataset.json", "Input COCO JSON file"),
        ("--images", Path, "IMAGE", "Directory containing images"),
        ("--output", Path, "dataset_split", "Output directory for splits"),
        ("--train-split", float, 0.8, "Training split fraction"),
        ("--valid-split", float, 0.1, "Validation split fraction"),
        ("--seed", int, 42, "Random seed"),
    )
    for flag, converter, fallback, description in option_table:
        cli.add_argument(flag, type=converter, default=fallback, help=description)

    options = cli.parse_args()

    split_coco_dataset(
        input_json=options.input,
        images_dir=options.images,
        output_dir=options.output,
        train_split=options.train_split,
        valid_split=options.valid_split,
        seed=options.seed,
    )
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()