2025-12-22 14:11:39 -07:00
|
|
|
"""
|
|
|
|
|
Train YOLOv6 for knot detection (MIT license - free for commercial use).
|
|
|
|
|
|
|
|
|
|
YOLOv6 is from Meituan, optimized for deployment on edge devices.
|
|
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
|
python train_yolov6.py --dataset-dir dataset_prepared --model yolov6n --epochs 100
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_yolov6(
    dataset_dir: Path,
    output_dir: Path,
    model_name: str = "yolov6n",
    epochs: int = 100,
    batch_size: int = 8,
    img_size: int = 640,
    learning_rate: float = 1e-2,
) -> bool:
    """Train a YOLOv6 model by running YOLOv6's ``tools/train.py`` script.

    All inputs are validated before anything is installed or written to
    disk, so a bad dataset path or model name fails fast without side
    effects.

    Args:
        dataset_dir: Dataset root containing ``train/`` and ``valid/``
            directories, each with an ``_annotations.coco.json`` file.
        output_dir: Directory for the generated ``data.yaml`` and the
            training outputs; created if it does not exist.
        model_name: One of 'yolov6n', 'yolov6s', 'yolov6m', 'yolov6l'.
        epochs: Number of training epochs.
        batch_size: Batch size.
        img_size: Input image size, forwarded as ``--img-size``.
        learning_rate: Accepted for interface compatibility but NOT
            forwarded to the training script.
            NOTE(review): YOLOv6 appears to read its learning rate from
            the config file rather than a CLI flag - confirm and wire it
            up if needed.

    Returns:
        True if the training subprocess exited with code 0, otherwise
        False (including when the training script cannot be found).

    Raises:
        ValueError: If the dataset layout is incomplete or ``model_name``
            is not supported.
        subprocess.CalledProcessError: If installing YOLOv6 via pip fails.
    """
    supported_models = ("yolov6n", "yolov6s", "yolov6m", "yolov6l")
    class_names = [
        "Live knot",
        "Dead knot",
        "Knot with crack",
        "Crack",
        "Resin",
        "Marrow",
        "Quartzity",
        "Knot missing",
        "Blue stain",
        "Overgrown",
    ]

    # --- Validate everything up front (no side effects before this passes) ---
    train_dir = dataset_dir / "train"
    valid_dir = dataset_dir / "valid"
    if not train_dir.is_dir() or not valid_dir.is_dir():
        raise ValueError("Dataset must have train/ and valid/ directories")

    for split_dir in (train_dir, valid_dir):
        annotations = split_dir / "_annotations.coco.json"
        if not annotations.exists():
            raise ValueError(f"Missing {annotations}")

    if model_name not in supported_models:
        raise ValueError(f"Model must be one of {list(supported_models)}")

    # --- Install YOLOv6 if not already installed ---
    try:
        import yolov6  # noqa: F401  (imported only to probe availability)
    except ImportError:
        print("Installing YOLOv6...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/meituan/YOLOv6.git",
        ])

    # --- Write the data config file consumed by the training script ---
    output_dir.mkdir(parents=True, exist_ok=True)
    data_yaml = output_dir / "data.yaml"
    with data_yaml.open("w", encoding="utf-8") as f:
        f.write(
            f"train: {train_dir.absolute()}\n"
            f"val: {valid_dir.absolute()}\n"
            "\n"
            # Derive the class count from the list so the two cannot drift.
            f"nc: {len(class_names)}\n"
            f"names: {class_names}\n"
        )

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Training YOLOv6-{model_name}")
    print(f"Dataset: {dataset_dir}")
    print(f"Output: {output_dir}")
    print(f"{banner}\n")

    # --- Build training command ---
    # Assumes a YOLOv6 source checkout sits next to the interpreter's
    # environment root (<env>/YOLOv6) - TODO confirm this matches deployment.
    yolov6_dir = Path(sys.executable).parent.parent / "YOLOv6"
    train_script = yolov6_dir / "tools/train.py"
    if not train_script.is_file():
        # Report the real cause instead of an opaque interpreter exit code.
        print(f"\n❌ YOLOv6 training script not found: {train_script}")
        return False

    cmd = [
        sys.executable,
        str(train_script),
        "--batch", str(batch_size),
        "--conf", str(yolov6_dir / f"configs/{model_name}.py"),
        "--data", str(data_yaml),
        "--epochs", str(epochs),
        "--img-size", str(img_size),
        "--device", "0",
        "--name", "yolov6_training",
        "--output-dir", str(output_dir),
    ]

    print(f"Running: {' '.join(cmd)}\n")

    result = subprocess.run(cmd, check=False)
    succeeded = result.returncode == 0

    if succeeded:
        print(f"\n{banner}")
        print("✓ Training complete!")
        print(f"Weights saved in: {output_dir}/yolov6_training")
        print(f"{banner}\n")
    else:
        print(f"\n❌ Training failed with exit code {result.returncode}")

    return succeeded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse command-line arguments and run YOLOv6 training.

    Exits the process with status 1 when training reports failure, so
    shell scripts and CI pipelines can detect an unsuccessful run.
    """
    parser = argparse.ArgumentParser(description="Train YOLOv6 for knot detection")
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Path to dataset directory with train/valid/test splits",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("runs/yolov6_training"),
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["yolov6n", "yolov6s", "yolov6m", "yolov6l"],
        default="yolov6n",
        help="YOLOv6 model variant (n=smallest/fastest, l=largest/most accurate)",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--img-size", type=int, default=640, help="Input image size")
    parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate")

    args = parser.parse_args()

    succeeded = train_yolov6(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_name=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        img_size=args.img_size,
        learning_rate=args.lr,
    )

    # Propagate failure through the exit status instead of discarding it.
    if not succeeded:
        sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|