# Source: saw_mill_knot_detection/train_yolov6.py (Python, 164 lines, 4.6 KiB)

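"""
Train YOLOv6 for knot detection.

YOLOv6 is an object detector from Meituan, optimized for deployment on edge
devices.

License note: this header originally described YOLOv6 as "MIT license - free
for commercial use". The upstream meituan/YOLOv6 repository appears to be
distributed under GPL-3.0 instead; check its LICENSE file before any
commercial use.

Usage:
    python train_yolov6.py --dataset-dir dataset_prepared --model yolov6n --epochs 100
"""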
import argparse
from pathlib import Path
import subprocess
import sys
def train_yolov6(
    dataset_dir: Path,
    output_dir: Path,
    model_name: str = "yolov6n",
    epochs: int = 100,
    batch_size: int = 8,
    img_size: int = 640,
    learning_rate: float = 1e-2,
    device: str = "0",
) -> bool:
    """Train a YOLOv6 model by running the YOLOv6 ``tools/train.py`` script.

    All inputs are validated before anything slow or side-effecting happens
    (package installation, directory creation, writing ``data.yaml``), so a
    bad argument fails immediately and leaves nothing behind.

    Args:
        dataset_dir: Dataset root containing ``train/`` and ``valid/``
            directories, each with an ``_annotations.coco.json`` file.
        output_dir: Where ``data.yaml`` is written; also passed to the
            training script as ``--output-dir``.
        model_name: One of ``yolov6n``, ``yolov6s``, ``yolov6m``, ``yolov6l``.
        epochs: Number of training epochs.
        batch_size: Batch size.
        img_size: Input image size, forwarded as ``--img-size``.
        learning_rate: Accepted for interface compatibility but NOT forwarded
            to the training command; a warning is printed when a non-default
            value is given.
        device: Value passed to the training script's ``--device`` option.

    Returns:
        True if the training subprocess exited with status 0, False if it
        failed or the training script could not be found.

    Raises:
        ValueError: If ``model_name`` is unknown or the dataset layout is
            incomplete.
        subprocess.CalledProcessError: If installing YOLOv6 via pip fails.
    """
    # Fail fast on bad input, before the (network) install and any file I/O.
    if model_name not in _YOLOV6_MODELS:
        raise ValueError(f"Model must be one of {list(_YOLOV6_MODELS)}")
    train_dir, valid_dir = _validate_dataset(dataset_dir)

    if learning_rate != _DEFAULT_LEARNING_RATE:
        # NOTE(review): the command built below has no learning-rate option;
        # presumably YOLOv6 reads it from the model config file - confirm.
        print(
            f"Warning: learning_rate={learning_rate} is not forwarded to the "
            "YOLOv6 training command; edit the model config to change it.",
            file=sys.stderr,
        )

    _ensure_yolov6_installed()

    output_dir.mkdir(parents=True, exist_ok=True)

    # Data config consumed by the training script via --data.
    # NOTE(review): the dataset check above looks for COCO-style annotation
    # files, while this config only points at the split directories - confirm
    # the YOLOv6 data loader accepts this layout.
    data_yaml = output_dir / "data.yaml"
    data_yaml.write_text(
        f"train: {train_dir.absolute()}\n"
        f"val: {valid_dir.absolute()}\n"
        "nc: 1\n"
        "names: ['knot']\n",
        encoding="utf-8",
    )

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Training YOLOv6-{model_name}")
    print(f"Dataset: {dataset_dir}")
    print(f"Output: {output_dir}")
    print(f"{banner}\n")

    # The YOLOv6 checkout is expected next to the interpreter's environment
    # root (two levels above the Python executable).
    yolov6_dir = Path(sys.executable).parent.parent / "YOLOv6"
    train_script = yolov6_dir / "tools/train.py"
    if not train_script.is_file():
        # Without this check the failure only shows up as an opaque
        # interpreter exit code from the subprocess.
        print(f"\n❌ YOLOv6 training script not found: {train_script}")
        return False

    cmd = [
        sys.executable,
        str(train_script),
        "--batch", str(batch_size),
        "--conf", str(yolov6_dir / f"configs/{model_name}.py"),
        "--data", str(data_yaml),
        "--img-size", str(img_size),
        "--epochs", str(epochs),
        "--device", str(device),
        "--name", "yolov6_training",
        "--output-dir", str(output_dir),
    ]
    print(f"Running: {' '.join(cmd)}\n")
    result = subprocess.run(cmd, check=False)

    succeeded = result.returncode == 0
    if succeeded:
        print(f"\n{banner}")
        print("✓ Training complete!")
        print(f"Weights saved in: {output_dir}/yolov6_training")
        print(f"{banner}\n")
    else:
        print(f"\n❌ Training failed with exit code {result.returncode}")
    return succeeded


# Model variants accepted by train_yolov6; each maps to configs/<name>.py.
_YOLOV6_MODELS = ("yolov6n", "yolov6s", "yolov6m", "yolov6l")

# Default of train_yolov6's learning_rate parameter, used to detect overrides.
_DEFAULT_LEARNING_RATE = 1e-2


def _validate_dataset(dataset_dir: Path) -> tuple:
    """Return ``(train_dir, valid_dir)`` or raise ValueError if incomplete."""
    train_dir = dataset_dir / "train"
    valid_dir = dataset_dir / "valid"
    if not (train_dir.is_dir() and valid_dir.is_dir()):
        raise ValueError("Dataset must have train/ and valid/ directories")
    for split_dir in (train_dir, valid_dir):
        annotations = split_dir / "_annotations.coco.json"
        if not annotations.is_file():
            raise ValueError(f"Missing {annotations}")
    return train_dir, valid_dir


def _ensure_yolov6_installed() -> None:
    """Install the YOLOv6 package from GitHub if it cannot be imported."""
    try:
        import yolov6  # noqa: F401  (imported only to probe availability)
    except ImportError:
        print("Installing YOLOv6...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/meituan/YOLOv6.git",
        ])
def main(argv=None) -> int:
    """Command-line entry point: parse arguments and run the training.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Returns:
        Process exit status: 0 when training succeeded, 1 when it failed or
        the inputs were rejected. Invalid command-line syntax still exits
        through argparse with status 2.
    """
    parser = argparse.ArgumentParser(description="Train YOLOv6 for knot detection")
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Path to dataset directory with train/valid/test splits",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("runs/yolov6_training"),
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["yolov6n", "yolov6s", "yolov6m", "yolov6l"],
        default="yolov6n",
        help="YOLOv6 model variant (n=smallest/fastest, l=largest/most accurate)",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--img-size", type=int, default=640, help="Input image size")
    parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate")
    args = parser.parse_args(argv)

    try:
        succeeded = train_yolov6(
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
            model_name=args.model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            img_size=args.img_size,
            learning_rate=args.lr,
        )
    except ValueError as exc:
        # Input validation errors are user errors: report them without a
        # traceback and signal failure through the exit status.
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    # Propagate training failure so shell scripts and CI can detect it.
    return 0 if succeeded else 1


if __name__ == "__main__":
    sys.exit(main())