removed gradio
This commit is contained in:
317
train_model.py
Executable file
317
train_model.py
Executable file
@ -0,0 +1,317 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training script for wood knot detection models.
|
||||
|
||||
Supports: RF-DETR, RT-DETR, YOLOv6, YOLOX
|
||||
All models are MIT/Apache 2.0 licensed - free for commercial use.
|
||||
|
||||
Usage:
|
||||
python train_model.py --framework rtdetr --dataset dataset_prepared --output runs/training
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def prepare_training_dataset(
|
||||
images_dir: Path,
|
||||
annotations_file: Path,
|
||||
output_dir: Path,
|
||||
train_split: float = 0.8,
|
||||
valid_split: float = 0.1
|
||||
) -> str:
|
||||
"""Prepare dataset in RF-DETR format (train/valid/test splits)."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load annotations
|
||||
with annotations_file.open() as f:
|
||||
annotations = json.load(f)
|
||||
|
||||
# Get all image files
|
||||
image_files = []
|
||||
for img_name in annotations.keys():
|
||||
img_path = images_dir / img_name
|
||||
if img_path.exists():
|
||||
image_files.append(img_name)
|
||||
|
||||
if not image_files:
|
||||
return "❌ No annotated images found"
|
||||
|
||||
# Shuffle and split
|
||||
import random
|
||||
random.seed(42)
|
||||
random.shuffle(image_files)
|
||||
|
||||
n_total = len(image_files)
|
||||
n_train = int(n_total * train_split)
|
||||
n_valid = int(n_total * valid_split)
|
||||
n_test = n_total - n_train - n_valid
|
||||
|
||||
splits = {
|
||||
'train': image_files[:n_train],
|
||||
'valid': image_files[n_train:n_train + n_valid],
|
||||
'test': image_files[n_train + n_valid:]
|
||||
}
|
||||
|
||||
# Create directories
|
||||
for split in ['train', 'valid', 'test']:
|
||||
(output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
|
||||
(output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy images and create labels
|
||||
for split, img_names in splits.items():
|
||||
for img_name in img_names:
|
||||
# Copy image
|
||||
src_img = images_dir / img_name
|
||||
dst_img = output_dir / split / 'images' / img_name
|
||||
shutil.copy2(src_img, dst_img)
|
||||
|
||||
# Convert annotations to YOLO format
|
||||
boxes = annotations.get(img_name, [])
|
||||
if boxes:
|
||||
img_width, img_height = get_image_size(src_img)
|
||||
yolo_lines = []
|
||||
for box in boxes:
|
||||
if isinstance(box, dict) and "bbox" in box:
|
||||
x1, y1, x2, y2 = box["bbox"]
|
||||
else:
|
||||
x1, y1, x2, y2 = box
|
||||
# Convert to YOLO format (normalized center x, center y, width, height)
|
||||
x_center = (x1 + x2) / 2 / img_width
|
||||
y_center = (y1 + y2) / 2 / img_height
|
||||
width = (x2 - x1) / img_width
|
||||
height = (y2 - y1) / img_height
|
||||
yolo_lines.append(f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
||||
|
||||
# Save label file
|
||||
label_file = output_dir / split / 'labels' / img_name.replace('.jpg', '.txt').replace('.png', '.txt')
|
||||
with label_file.open('w') as f:
|
||||
f.write('\n'.join(yolo_lines))
|
||||
|
||||
# Create data.yaml for YOLO models
|
||||
data_yaml = f"""train: {output_dir}/train/images
|
||||
val: {output_dir}/valid/images
|
||||
test: {output_dir}/test/images
|
||||
|
||||
nc: 1
|
||||
names: ['knot']
|
||||
"""
|
||||
|
||||
with (output_dir / 'data.yaml').open('w') as f:
|
||||
f.write(data_yaml)
|
||||
|
||||
return f"✓ Dataset prepared: {n_train} train, {n_valid} valid, {n_test} test images"
|
||||
|
||||
|
||||
def get_image_size(img_path: Path) -> tuple[int, int]:
|
||||
"""Get image dimensions."""
|
||||
from PIL import Image
|
||||
with Image.open(img_path) as img:
|
||||
return img.size
|
||||
|
||||
|
||||
def train_rfdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float):
|
||||
"""Train RF-DETR model."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Import and run training
|
||||
try:
|
||||
from train_rfdetr import main as train_main
|
||||
import sys
|
||||
|
||||
# Set up arguments
|
||||
sys.argv = [
|
||||
'train_rfdetr.py',
|
||||
'--data', str(dataset_dir / 'data.yaml'),
|
||||
'--output', str(output_dir),
|
||||
'--model', model_size,
|
||||
'--epochs', str(epochs),
|
||||
'--batch', str(batch_size),
|
||||
'--lr', str(lr)
|
||||
]
|
||||
|
||||
train_main()
|
||||
return "✓ RF-DETR training completed"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ RF-DETR training failed: {e}"
|
||||
|
||||
|
||||
def train_rtdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float):
|
||||
"""Train RT-DETR model."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
from train_rtdetr import main as train_main
|
||||
import sys
|
||||
|
||||
sys.argv = [
|
||||
'train_rtdetr.py',
|
||||
'--data', str(dataset_dir / 'data.yaml'),
|
||||
'--output', str(output_dir),
|
||||
'--model', model_size,
|
||||
'--epochs', str(epochs),
|
||||
'--batch', str(batch_size),
|
||||
'--lr', str(lr)
|
||||
]
|
||||
|
||||
train_main()
|
||||
return "✓ RT-DETR training completed"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ RT-DETR training failed: {e}"
|
||||
|
||||
|
||||
def train_yolov6(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float):
|
||||
"""Train YOLOv6 model."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
from train_yolov6 import main as train_main
|
||||
import sys
|
||||
|
||||
sys.argv = [
|
||||
'train_yolov6.py',
|
||||
'--data', str(dataset_dir / 'data.yaml'),
|
||||
'--output', str(output_dir),
|
||||
'--model', model_size,
|
||||
'--epochs', str(epochs),
|
||||
'--batch', str(batch_size),
|
||||
'--lr', str(lr)
|
||||
]
|
||||
|
||||
train_main()
|
||||
return "✓ YOLOv6 training completed"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ YOLOv6 training failed: {e}"
|
||||
|
||||
|
||||
def train_yolox(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float):
|
||||
"""Train YOLOX model."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
from train_yolox import main as train_main
|
||||
import sys
|
||||
|
||||
sys.argv = [
|
||||
'train_yolox.py',
|
||||
'--data', str(dataset_dir / 'data.yaml'),
|
||||
'--output', str(output_dir),
|
||||
'--model', model_size,
|
||||
'--epochs', str(epochs),
|
||||
'--batch', str(batch_size),
|
||||
'--lr', str(lr)
|
||||
]
|
||||
|
||||
train_main()
|
||||
return "✓ YOLOX training completed"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ YOLOX training failed: {e}"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train object detection models for wood knot detection")
|
||||
parser.add_argument(
|
||||
'--framework',
|
||||
choices=['rf-detr', 'rt-detr', 'yolov6', 'yolox'],
|
||||
default='rt-detr',
|
||||
help='Model framework to train'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Path to prepared dataset directory'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
type=Path,
|
||||
default=Path('runs/training'),
|
||||
help='Output directory for trained model'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model-size',
|
||||
choices=['nano', 'small', 'medium', 'base'],
|
||||
default='small',
|
||||
help='Model size/variant'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Number of training epochs'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Batch size for training'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help='Learning rate'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--prepare-dataset',
|
||||
action='store_true',
|
||||
help='Prepare dataset from annotations first'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--images-dir',
|
||||
type=Path,
|
||||
default=Path('IMAGE'),
|
||||
help='Images directory (for --prepare-dataset)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--annotations',
|
||||
type=Path,
|
||||
default=Path('annotations.json'),
|
||||
help='Annotations file (for --prepare-dataset)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.prepare_dataset:
|
||||
print("Preparing dataset...")
|
||||
result = prepare_training_dataset(
|
||||
args.images_dir,
|
||||
args.annotations,
|
||||
args.dataset
|
||||
)
|
||||
print(result)
|
||||
if "❌" in result:
|
||||
return
|
||||
|
||||
print(f"Training {args.framework.upper()} model...")
|
||||
print(f"Dataset: {args.dataset}")
|
||||
print(f"Output: {args.output}")
|
||||
print(f"Model size: {args.model_size}")
|
||||
print(f"Epochs: {args.epochs}")
|
||||
print(f"Batch size: {args.batch_size}")
|
||||
print(f"Learning rate: {args.lr}")
|
||||
|
||||
# Train based on framework
|
||||
if args.framework == 'rf-detr':
|
||||
result = train_rfdetr(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
|
||||
elif args.framework == 'rt-detr':
|
||||
result = train_rtdetr(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
|
||||
elif args.framework == 'yolov6':
|
||||
result = train_yolov6(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
|
||||
elif args.framework == 'yolox':
|
||||
result = train_yolox(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
|
||||
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()</content>
|
||||
<parameter name="filePath">/home/dillon/_code/saw_mill_knot_detection/train_model.py
|
||||
Reference in New Issue
Block a user