Files
saw_mill_knot_detection/train_model.py

317 lines
9.4 KiB
Python
Raw Permalink Normal View History

2025-12-23 18:24:40 -07:00
#!/usr/bin/env python3
"""
Training script for wood knot detection models.
Supports: RF-DETR, RT-DETR, YOLOv6, YOLOX
Licenses differ per implementation - verify each one before commercial use (for example, the reference YOLOv6 repository is GPL-3.0, not MIT/Apache 2.0).
Usage:
python train_model.py --framework rt-detr --dataset dataset_prepared --output runs/training
"""
import argparse
import json
import shutil
import subprocess
from pathlib import Path
from typing import Any
def prepare_training_dataset(
    images_dir: Path,
    annotations_file: Path,
    output_dir: Path,
    train_split: float = 0.8,
    valid_split: float = 0.1
) -> str:
    """Copy annotated images into train/valid/test splits with YOLO-format labels.

    Every key of the annotations JSON is treated as an image filename relative
    to *images_dir*; names whose file does not exist are skipped. The remaining
    images are shuffled reproducibly (fixed seed 42) and split into ``train`` /
    ``valid`` / ``test`` using *train_split* and *valid_split*; ``test`` gets
    whatever is left. Each split gets ``images/`` and ``labels/`` directories,
    and a ``data.yaml`` describing the single ``knot`` class is written to
    *output_dir*.

    Args:
        images_dir: Directory containing the source images.
        annotations_file: JSON object mapping image filename -> list of boxes.
            A box is either ``[x1, y1, x2, y2]`` in pixels or a dict whose
            ``"bbox"`` key holds that list.
        output_dir: Destination root; created if missing.
        train_split: Fraction of images assigned to ``train`` (>= 0).
        valid_split: Fraction of images assigned to ``valid`` (>= 0).
            ``train_split + valid_split`` must not exceed 1.

    Returns:
        A status message: starts with ``"✓"`` on success, ``"❌"`` on failure.
    """
    import random

    # Reject fractions that would produce a negative/overlapping split
    # (small tolerance for float round-off such as 0.7 + 0.3).
    if train_split < 0 or valid_split < 0 or train_split + valid_split > 1 + 1e-9:
        return (
            f"❌ Invalid split fractions: train={train_split}, valid={valid_split} "
            "(both must be >= 0 and sum to at most 1)"
        )

    output_dir.mkdir(parents=True, exist_ok=True)

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with annotations_file.open(encoding='utf-8') as f:
        annotations = json.load(f)

    # Keep only annotated images that actually exist on disk.
    image_files = [name for name in annotations if (images_dir / name).exists()]
    if not image_files:
        return "❌ No annotated images found"

    # A private RNG seeded with 42 yields the same order as random.seed(42)
    # followed by random.shuffle(), without reseeding the global generator.
    random.Random(42).shuffle(image_files)

    n_total = len(image_files)
    n_train = int(n_total * train_split)
    n_valid = int(n_total * valid_split)
    n_test = n_total - n_train - n_valid
    splits = {
        'train': image_files[:n_train],
        'valid': image_files[n_train:n_train + n_valid],
        'test': image_files[n_train + n_valid:],
    }

    for split, img_names in splits.items():
        images_out = output_dir / split / 'images'
        labels_out = output_dir / split / 'labels'
        images_out.mkdir(parents=True, exist_ok=True)
        labels_out.mkdir(parents=True, exist_ok=True)

        for img_name in img_names:
            src_img = images_dir / img_name
            shutil.copy2(src_img, images_out / img_name)

            boxes = annotations[img_name]
            if not boxes:
                # Images without boxes get no label file.
                continue

            img_width, img_height = get_image_size(src_img)
            yolo_lines = []
            for box in boxes:
                if isinstance(box, dict) and "bbox" in box:
                    x1, y1, x2, y2 = box["bbox"]
                else:
                    x1, y1, x2, y2 = box
                # YOLO format: class id, then center x/y and width/height,
                # all normalized by the image dimensions.
                x_center = (x1 + x2) / 2 / img_width
                y_center = (y1 + y2) / 2 / img_height
                width = (x2 - x1) / img_width
                height = (y2 - y1) / img_height
                yolo_lines.append(f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

            # with_suffix swaps only the real extension and handles any image
            # type (.jpeg, .JPG, .bmp, ...), unlike chained str.replace calls.
            label_file = labels_out / Path(img_name).with_suffix('.txt')
            label_file.write_text('\n'.join(yolo_lines), encoding='utf-8')

    # Absolute paths: some loaders (Ultralytics, for one) resolve relative
    # entries against the YAML file's own directory rather than the CWD.
    root = output_dir.resolve()
    data_yaml = (
        f"train: {root / 'train' / 'images'}\n"
        f"val: {root / 'valid' / 'images'}\n"
        f"test: {root / 'test' / 'images'}\n"
        "nc: 1\n"
        "names: ['knot']\n"
    )
    (output_dir / 'data.yaml').write_text(data_yaml, encoding='utf-8')
    return f"✓ Dataset prepared: {n_train} train, {n_valid} valid, {n_test} test images"
def get_image_size(img_path: Path) -> tuple[int, int]:
    """Return the ``(width, height)`` of the image file at *img_path*.

    Pillow is imported inside the function, so it is only required when this
    helper is actually called.
    """
    from PIL import Image

    with Image.open(img_path) as picture:
        width, height = picture.size
    return width, height
def _run_training_script(
    module_name: str,
    label: str,
    dataset_dir: Path,
    output_dir: Path,
    model_size: str,
    epochs: int,
    batch_size: int,
    lr: float,
) -> str:
    """Call ``<module_name>.main()`` as if it had been run from the command line.

    The module is imported by name (it must be importable from ``sys.path``)
    and its ``main`` is invoked with a synthetic ``sys.argv``. The caller's
    ``sys.argv`` is always restored afterwards, whether training succeeds,
    raises, or calls ``sys.exit``.

    Args:
        module_name: Importable module exposing a ``main()`` function.
        label: Human-readable framework name used in the status message.
        dataset_dir: Directory containing ``data.yaml``.
        output_dir: Where the training script should write results; created
            if missing.
        model_size: Model variant passed as ``--model``.
        epochs: Passed as ``--epochs``.
        batch_size: Passed as ``--batch``.
        lr: Passed as ``--lr``.

    Returns:
        ``"✓ <label> training completed"`` on success, otherwise a message
        starting with ``"❌ <label> training failed:"``.
    """
    import importlib
    import sys

    output_dir.mkdir(parents=True, exist_ok=True)
    script_argv = [
        f'{module_name}.py',
        '--data', str(dataset_dir / 'data.yaml'),
        '--output', str(output_dir),
        '--model', model_size,
        '--epochs', str(epochs),
        '--batch', str(batch_size),
        '--lr', str(lr),
    ]
    saved_argv = sys.argv
    try:
        train_main = importlib.import_module(module_name).main
        sys.argv = script_argv
        train_main()
    except SystemExit as exc:
        # A script's main() may exit via argparse errors or sys.exit(); that
        # must not terminate this process. Code None/0 means success.
        if exc.code not in (None, 0):
            return f"❌ {label} training failed: exited with status {exc.code}"
    except Exception as exc:
        return f"❌ {label} training failed: {exc}"
    finally:
        sys.argv = saved_argv
    return f"✓ {label} training completed"


def train_rfdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train an RF-DETR model via ``train_rfdetr.main()``; return a status message."""
    return _run_training_script(
        'train_rfdetr', 'RF-DETR', dataset_dir, output_dir, model_size, epochs, batch_size, lr
    )


def train_rtdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train an RT-DETR model via ``train_rtdetr.main()``; return a status message."""
    return _run_training_script(
        'train_rtdetr', 'RT-DETR', dataset_dir, output_dir, model_size, epochs, batch_size, lr
    )


def train_yolov6(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train a YOLOv6 model via ``train_yolov6.main()``; return a status message."""
    return _run_training_script(
        'train_yolov6', 'YOLOv6', dataset_dir, output_dir, model_size, epochs, batch_size, lr
    )


def train_yolox(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train a YOLOX model via ``train_yolox.main()``; return a status message."""
    return _run_training_script(
        'train_yolox', 'YOLOX', dataset_dir, output_dir, model_size, epochs, batch_size, lr
    )
def main():
    """Command-line entry point: optionally prepare the dataset, then train one model.

    With ``--prepare-dataset``, ``--dataset`` is used as the output directory
    of the preparation step, and training is skipped if preparation fails.
    """
    parser = argparse.ArgumentParser(description="Train object detection models for wood knot detection")
    parser.add_argument(
        '--framework',
        choices=['rf-detr', 'rt-detr', 'yolov6', 'yolox'],
        default='rt-detr',
        help='Model framework to train'
    )
    parser.add_argument(
        '--dataset',
        type=Path,
        required=True,
        help='Path to prepared dataset directory'
    )
    parser.add_argument(
        '--output',
        type=Path,
        default=Path('runs/training'),
        help='Output directory for trained model'
    )
    parser.add_argument(
        '--model-size',
        choices=['nano', 'small', 'medium', 'base'],
        default='small',
        help='Model size/variant'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=20,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Batch size for training'
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-4,
        help='Learning rate'
    )
    parser.add_argument(
        '--prepare-dataset',
        action='store_true',
        help='Prepare dataset from annotations first'
    )
    parser.add_argument(
        '--images-dir',
        type=Path,
        default=Path('IMAGE'),
        help='Images directory (for --prepare-dataset)'
    )
    parser.add_argument(
        '--annotations',
        type=Path,
        default=Path('annotations.json'),
        help='Annotations file (for --prepare-dataset)'
    )
    args = parser.parse_args()

    if args.prepare_dataset:
        print("Preparing dataset...")
        result = prepare_training_dataset(
            args.images_dir,
            args.annotations,
            args.dataset
        )
        print(result)
        # prepare_training_dataset reports failure with a "❌" prefix. The
        # check must look for that marker; testing for the empty string is
        # always true and would skip training even after a successful run.
        if result.startswith("❌"):
            return

    print(f"Training {args.framework.upper()} model...")
    print(f"Dataset: {args.dataset}")
    print(f"Output: {args.output}")
    print(f"Model size: {args.model_size}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.lr}")

    # Keys mirror the --framework choices, so lookup cannot miss.
    trainers = {
        'rf-detr': train_rfdetr,
        'rt-detr': train_rtdetr,
        'yolov6': train_yolov6,
        'yolox': train_yolox,
    }
    train = trainers[args.framework]
    result = train(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
    print(result)
if __name__ == "__main__":
    main()
<parameter name="filePath">/home/dillon/_code/saw_mill_knot_detection/train_model.py