I'm building a web app with a FastAPI backend (async/await Python). Users upload leaf photos via the API, and the server should return: 1) the plant species, and 2) a disease label or "healthy".

Constraints:

  1. Generalization: Must handle multiple crops. Users can upload "any" plant leaf, not just tomato/corn. Target 15+ species.

  2. Server inference: Runs on GPU server, not mobile. Latency 1-2s is acceptable, so model size isn't a bottleneck.

  3. Pre-trained + 100% free: Need open-source weights for transfer learning. No paid APIs. License must allow commercial use.

  4. Dataset: Starting with the PlantVillage dataset + ~2,000 custom field images. The gap between lab images and real field images is a domain shift issue.

  5. Tech stack: PyTorch + timm library. Inference runs in async endpoints, so I use run_in_executor to avoid blocking.
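
For context on constraint 5, here is a minimal sketch of the offloading pattern, not my actual endpoint. It uses only the standard library so it runs on its own. `blocking_predict`, `predict_async`, the placeholder labels, and the single-worker executor are all assumptions made for illustration; in the real app the blocking function would wrap the PyTorch forward pass.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

# Assumption: one worker thread, so calls into the model are serialized.
executor = ThreadPoolExecutor(max_workers=1)


def blocking_predict(image_bytes: bytes) -> dict:
    """Hypothetical stand-in for preprocessing + the model forward pass."""
    time.sleep(0.05)  # simulates blocking inference work
    # Placeholder labels only; a real model would produce these.
    return {"species": "unknown", "disease": "healthy", "num_bytes": len(image_bytes)}


async def predict_async(image_bytes: bytes) -> dict:
    """Run the blocking call in the executor so the event loop stays free."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, blocking_predict, image_bytes)


async def demo() -> list:
    # Two concurrent "requests"; the event loop is not blocked while they run.
    return await asyncio.gather(predict_async(b"abc"), predict_async(b"defgh"))
```

Calling `asyncio.run(demo())` returns one result dict per simulated upload. An async route handler would await `predict_async` the same way.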

What I tried: Fine-tuned ResNet50 on PlantVillage. I get 95% accuracy on lab images, but it drops to ~62% on field images. The model seems to be overfitting to the clean lab backgrounds.

Questions:

  1. For multi-crop + multi-disease, is a 2-stage approach better (Model A for species ID, then Model B for disease per species), or a single multi-label model?

  2. Between ConvNeXt-Base, Swin-Base, and ViT-Base, which fine-tunes best on PlantVillage + field data for accuracy in 2025?

  3. Are there plant-specific foundation models/checkpoints better than ImageNet pre-training for this domain?

I'm looking for architecture + dataset + fine-tuning strategy advice, not code.