317 lines
9.4 KiB
Python
317 lines
9.4 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Training script for wood knot detection models.
|
||
|
|
|
||
|
|
Supports: RF-DETR, RT-DETR, YOLOv6, YOLOX
|
||
|
|
All models are MIT/Apache 2.0 licensed - free for commercial use.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python train_model.py --framework rt-detr --dataset dataset_prepared --output runs/training
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import shutil
|
||
|
|
import subprocess
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
|
||
|
|
def prepare_training_dataset(
    images_dir: Path,
    annotations_file: Path,
    output_dir: Path,
    train_split: float = 0.8,
    valid_split: float = 0.1,
) -> str:
    """Prepare dataset in RF-DETR format (train/valid/test splits).

    Reads a JSON mapping of image file name -> list of boxes, keeps the
    entries whose image exists in ``images_dir``, shuffles them with a fixed
    seed, and copies them into ``output_dir/{train,valid,test}/images``.
    For every image that has boxes, a YOLO-style label file (class id 0,
    normalized center x / center y / width / height) is written to the
    matching ``labels`` directory. A ``data.yaml`` describing the three
    splits is written to ``output_dir``.

    Args:
        images_dir: Directory containing the source images.
        annotations_file: JSON file mapping image names to box lists. Each
            box is either a 4-item sequence ``x1, y1, x2, y2`` or a dict
            carrying that sequence under the ``"bbox"`` key.
        output_dir: Destination directory; created if missing.
        train_split: Fraction of images assigned to the train split.
        valid_split: Fraction of images assigned to the valid split; the
            remainder goes to the test split.

    Returns:
        A status message. Messages that contain "❌" signal failure.
    """
    import random

    # Reject fractions that would yield a negative test split.
    if train_split < 0 or valid_split < 0 or train_split + valid_split > 1 + 1e-9:
        return (
            f"❌ Invalid split fractions: train_split={train_split}, "
            f"valid_split={valid_split}"
        )

    output_dir.mkdir(parents=True, exist_ok=True)

    # Load annotations
    with annotations_file.open(encoding="utf-8") as f:
        annotations = json.load(f)

    # Keep only annotated images that are actually present on disk.
    image_files = [
        img_name for img_name in annotations if (images_dir / img_name).exists()
    ]

    if not image_files:
        return "❌ No annotated images found"

    # Shuffle with a private generator: same seed-42 ordering, but the
    # process-wide random state is left untouched for the caller.
    random.Random(42).shuffle(image_files)

    n_total = len(image_files)
    n_train = int(n_total * train_split)
    n_valid = int(n_total * valid_split)
    n_test = n_total - n_train - n_valid

    splits = {
        'train': image_files[:n_train],
        'valid': image_files[n_train:n_train + n_valid],
        'test': image_files[n_train + n_valid:],
    }

    # Create directories
    for split in splits:
        (output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

    # Copy images and create labels
    for split, img_names in splits.items():
        images_out = output_dir / split / 'images'
        labels_out = output_dir / split / 'labels'

        for img_name in img_names:
            src_img = images_dir / img_name
            shutil.copy2(src_img, images_out / img_name)

            boxes = annotations.get(img_name, [])
            if not boxes:
                # No boxes: the image is copied but gets no label file.
                continue

            img_width, img_height = get_image_size(src_img)
            yolo_lines = []
            for box in boxes:
                if isinstance(box, dict) and "bbox" in box:
                    x1, y1, x2, y2 = box["bbox"]
                else:
                    x1, y1, x2, y2 = box

                # Convert to YOLO format (normalized center x, center y, width, height)
                x_center = (x1 + x2) / 2 / img_width
                y_center = (y1 + y2) / 2 / img_height
                width = (x2 - x1) / img_width
                height = (y2 - y1) / img_height
                yolo_lines.append(
                    f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
                )

            # Swap only the real extension, whatever it is (.jpeg, .JPG, ...),
            # instead of substring-replacing '.jpg' / '.png' anywhere in the name.
            label_file = labels_out / Path(img_name).with_suffix('.txt')
            with label_file.open('w', encoding="utf-8") as f:
                f.write('\n'.join(yolo_lines))

    # Create data.yaml for YOLO models
    data_yaml = f"""train: {output_dir}/train/images
val: {output_dir}/valid/images
test: {output_dir}/test/images

nc: 1
names: ['knot']
"""

    with (output_dir / 'data.yaml').open('w', encoding="utf-8") as f:
        f.write(data_yaml)

    return f"✓ Dataset prepared: {n_train} train, {n_valid} valid, {n_test} test images"
|
||
|
|
|
||
|
|
|
||
|
|
def get_image_size(img_path: Path) -> tuple[int, int]:
    """Return the ``(width, height)`` pair reported by Pillow for *img_path*."""
    # Imported lazily so Pillow is only required when labels are generated.
    from PIL import Image

    with Image.open(img_path) as opened:
        dimensions = opened.size
    return dimensions
|
||
|
|
|
||
|
|
|
||
|
|
def train_rfdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train RF-DETR model.

    Imports ``main`` from the ``train_rfdetr`` module and invokes it with a
    synthesized command line placed in ``sys.argv``. The caller's original
    ``sys.argv`` is restored afterwards, whether training succeeds or fails.

    Args:
        dataset_dir: Directory containing ``data.yaml``.
        output_dir: Directory for training output; created if missing.
        model_size: Value forwarded as ``--model``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch``.
        lr: Value forwarded as ``--lr``.

    Returns:
        A "✓ ..." message on success, or a "❌ ..." message carrying the
        exception text when the import or the training run raises.
    """
    import sys

    output_dir.mkdir(parents=True, exist_ok=True)

    # Remember the real command line so it can be put back in ``finally``.
    saved_argv = sys.argv
    try:
        from train_rfdetr import main as train_main

        # The trainer's main() takes its options from sys.argv.
        sys.argv = [
            'train_rfdetr.py',
            '--data', str(dataset_dir / 'data.yaml'),
            '--output', str(output_dir),
            '--model', model_size,
            '--epochs', str(epochs),
            '--batch', str(batch_size),
            '--lr', str(lr),
        ]

        train_main()
        return "✓ RF-DETR training completed"

    except Exception as e:
        return f"❌ RF-DETR training failed: {e}"

    finally:
        sys.argv = saved_argv
|
||
|
|
|
||
|
|
|
||
|
|
def train_rtdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train RT-DETR model.

    Imports ``main`` from the ``train_rtdetr`` module and invokes it with a
    synthesized command line placed in ``sys.argv``. The caller's original
    ``sys.argv`` is restored afterwards, whether training succeeds or fails.

    Args:
        dataset_dir: Directory containing ``data.yaml``.
        output_dir: Directory for training output; created if missing.
        model_size: Value forwarded as ``--model``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch``.
        lr: Value forwarded as ``--lr``.

    Returns:
        A "✓ ..." message on success, or a "❌ ..." message carrying the
        exception text when the import or the training run raises.
    """
    import sys

    output_dir.mkdir(parents=True, exist_ok=True)

    # Remember the real command line so it can be put back in ``finally``.
    saved_argv = sys.argv
    try:
        from train_rtdetr import main as train_main

        # The trainer's main() takes its options from sys.argv.
        sys.argv = [
            'train_rtdetr.py',
            '--data', str(dataset_dir / 'data.yaml'),
            '--output', str(output_dir),
            '--model', model_size,
            '--epochs', str(epochs),
            '--batch', str(batch_size),
            '--lr', str(lr),
        ]

        train_main()
        return "✓ RT-DETR training completed"

    except Exception as e:
        return f"❌ RT-DETR training failed: {e}"

    finally:
        sys.argv = saved_argv
|
||
|
|
|
||
|
|
|
||
|
|
def train_yolov6(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train YOLOv6 model.

    Imports ``main`` from the ``train_yolov6`` module and invokes it with a
    synthesized command line placed in ``sys.argv``. The caller's original
    ``sys.argv`` is restored afterwards, whether training succeeds or fails.

    Args:
        dataset_dir: Directory containing ``data.yaml``.
        output_dir: Directory for training output; created if missing.
        model_size: Value forwarded as ``--model``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch``.
        lr: Value forwarded as ``--lr``.

    Returns:
        A "✓ ..." message on success, or a "❌ ..." message carrying the
        exception text when the import or the training run raises.
    """
    import sys

    output_dir.mkdir(parents=True, exist_ok=True)

    # Remember the real command line so it can be put back in ``finally``.
    saved_argv = sys.argv
    try:
        from train_yolov6 import main as train_main

        # The trainer's main() takes its options from sys.argv.
        sys.argv = [
            'train_yolov6.py',
            '--data', str(dataset_dir / 'data.yaml'),
            '--output', str(output_dir),
            '--model', model_size,
            '--epochs', str(epochs),
            '--batch', str(batch_size),
            '--lr', str(lr),
        ]

        train_main()
        return "✓ YOLOv6 training completed"

    except Exception as e:
        return f"❌ YOLOv6 training failed: {e}"

    finally:
        sys.argv = saved_argv
|
||
|
|
|
||
|
|
|
||
|
|
def train_yolox(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train YOLOX model.

    Imports ``main`` from the ``train_yolox`` module and invokes it with a
    synthesized command line placed in ``sys.argv``. The caller's original
    ``sys.argv`` is restored afterwards, whether training succeeds or fails.

    Args:
        dataset_dir: Directory containing ``data.yaml``.
        output_dir: Directory for training output; created if missing.
        model_size: Value forwarded as ``--model``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch``.
        lr: Value forwarded as ``--lr``.

    Returns:
        A "✓ ..." message on success, or a "❌ ..." message carrying the
        exception text when the import or the training run raises.
    """
    import sys

    output_dir.mkdir(parents=True, exist_ok=True)

    # Remember the real command line so it can be put back in ``finally``.
    saved_argv = sys.argv
    try:
        from train_yolox import main as train_main

        # The trainer's main() takes its options from sys.argv.
        sys.argv = [
            'train_yolox.py',
            '--data', str(dataset_dir / 'data.yaml'),
            '--output', str(output_dir),
            '--model', model_size,
            '--epochs', str(epochs),
            '--batch', str(batch_size),
            '--lr', str(lr),
        ]

        train_main()
        return "✓ YOLOX training completed"

    except Exception as e:
        return f"❌ YOLOX training failed: {e}"

    finally:
        sys.argv = saved_argv
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse command-line options, optionally build the dataset, then train.

    With ``--prepare-dataset`` the dataset is first generated from
    ``--images-dir`` and ``--annotations`` into ``--dataset``; if that step
    reports a failure (its message contains "❌") the function returns
    without training. Otherwise the trainer matching ``--framework`` is
    called and its status message is printed.
    """
    parser = argparse.ArgumentParser(description="Train object detection models for wood knot detection")
    add_option = parser.add_argument

    add_option(
        '--framework',
        choices=['rf-detr', 'rt-detr', 'yolov6', 'yolox'],
        default='rt-detr',
        help='Model framework to train',
    )
    add_option(
        '--dataset',
        type=Path,
        required=True,
        help='Path to prepared dataset directory',
    )
    add_option(
        '--output',
        type=Path,
        default=Path('runs/training'),
        help='Output directory for trained model',
    )
    add_option(
        '--model-size',
        choices=['nano', 'small', 'medium', 'base'],
        default='small',
        help='Model size/variant',
    )
    add_option('--epochs', type=int, default=20, help='Number of training epochs')
    add_option('--batch-size', type=int, default=4, help='Batch size for training')
    add_option('--lr', type=float, default=1e-4, help='Learning rate')
    add_option(
        '--prepare-dataset',
        action='store_true',
        help='Prepare dataset from annotations first',
    )
    add_option(
        '--images-dir',
        type=Path,
        default=Path('IMAGE'),
        help='Images directory (for --prepare-dataset)',
    )
    add_option(
        '--annotations',
        type=Path,
        default=Path('annotations.json'),
        help='Annotations file (for --prepare-dataset)',
    )

    args = parser.parse_args()

    if args.prepare_dataset:
        print("Preparing dataset...")
        prep_status = prepare_training_dataset(
            args.images_dir,
            args.annotations,
            args.dataset,
        )
        print(prep_status)
        # A failed preparation leaves nothing to train on.
        if "❌" in prep_status:
            return

    print(f"Training {args.framework.upper()} model...")
    print(f"Dataset: {args.dataset}")
    print(f"Output: {args.output}")
    print(f"Model size: {args.model_size}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.lr}")

    # One trainer per --framework choice; argparse has already rejected
    # any value outside this set.
    trainers = {
        'rf-detr': train_rfdetr,
        'rt-detr': train_rtdetr,
        'yolov6': train_yolov6,
        'yolox': train_yolox,
    }
    run_training = trainers[args.framework]
    outcome = run_training(
        args.dataset,
        args.output,
        args.model_size,
        args.epochs,
        args.batch_size,
        args.lr,
    )

    print(outcome)
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|