Optimizing parameters in a WOFOST crop model using diffWOFOST
This Jupyter notebook demonstrates the optimization of parameters in a
differentiable model using the diffwofost package. The package provides
differentiable implementations of the WOFOST model and its associated
sub-models. As diffwofost is under active development, this notebook focuses on
one sub-model: phenology.
1. Phenology¶
In this section, we will demonstrate how to optimize the parameters TSUMEM, TBASEM, TSUM1, and TSUM2 of the
phenology model using its differentiable implementation.
The optimization will be done using the Adam optimizer from torch.optim.
1.1. Software requirements¶
To run this notebook, we need to install diffwofost, the differentiable
version of the WOFOST models. Since the package is under active development, make
sure you have the latest version of diffwofost installed in your
Python environment. You can install it using pip:
# install diffwofost
!pip install diffwofost
# ---- import libraries ----
import copy
import torch
import numpy
from pathlib import Path
from diffwofost.physical_models.config import Configuration
from diffwofost.physical_models.crop.phenology import DVS_Phenology
from diffwofost.physical_models.utils import EngineTestHelper
from diffwofost.physical_models.utils import prepare_engine_input
from diffwofost.physical_models.utils import get_test_data
# --- run on CPU ------
from diffwofost.physical_models.config import ComputeConfig
ComputeConfig.set_device('cpu')
# ---- disable a warning: this will be fixed in the future ----
import warnings
warnings.filterwarnings("ignore", message="To copy construct from a tensor.*")
1.2. Data¶
A test dataset of DVS (Development stage) will be used to optimize the parameters:
- TSUMEM: Temperature sum from sowing to emergence
- TBASEM: Base temperature for emergence
- TSUM1: Temperature sum from emergence to anthesis
- TSUM2: Temperature sum from anthesis to maturity
The data is stored in the PCSE tests folder, and can be downloaded from the PCSE repository.
You can select any of the files related to phenology model with a file name that follows the pattern
test_phenology_wofost72_*.yaml. Each file contains different data depending on the location and crop type.
For example, you can download the file "test_phenology_wofost72_17.yaml" as:
import urllib.request
filename = "test_phenology_wofost72_17.yaml"
url = f"https://raw.githubusercontent.com/ajwdewit/pcse/refs/heads/master/tests/test_data/{filename}"
urllib.request.urlretrieve(url, filename)
print(f"Downloaded: {filename}")
Downloaded: test_phenology_wofost72_17.yaml
# ---- Check the path to the files that are downloaded as explained above ----
test_data_path = "test_phenology_wofost72_17.yaml"
# ---- Here we read the test data and set some variables ----
test_data = get_test_data(test_data_path)
crop_model_params = [
    "TSUMEM",
    "TBASEM",
    "TEFFMX",
    "TSUM1",
    "TSUM2",
    "IDSL",
    "DLO",
    "DLC",
    "DVSI",
    "DVSEND",
    "DTSMTB",
    "VERNSAT",
    "VERNBASE",
    "VERNDVS",
]
(crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = (
    prepare_engine_input(test_data, crop_model_params)
)
expected_results = test_data["ModelResults"]
expected_dvs = torch.tensor(
    [float(item["DVS"]) for item in expected_results], dtype=torch.float32
)  # shape: [time_steps]
# ---- don't change this: in this config file we specify the differentiable version of DVS_Phenology ----
phenology_config = Configuration(
    CROP=DVS_Phenology,
    OUTPUT_VARS=["DVR", "DVS", "TSUM", "TSUME", "VERN"],
)
1.3. Helper classes/functions¶
The model parameters should stay in a valid range. To ensure this, we will use the
BoundedParameter class with (min, max) bounds and an initial value for each
parameter. You may change these values depending on the crop type and
location, but avoid a very narrow range: otherwise the gradients will be very
small and the optimization will be very slow.
# ---- Adjust the values if needed ----
TSUMEM_MIN, TSUMEM_MAX, TSUMEM_INIT = (0.0, 200, 90)
TBASEM_MIN, TBASEM_MAX, TBASEM_INIT = (0.0, 10.0, 2.0)
TSUM1_MIN, TSUM1_MAX, TSUM1_INIT = (0.0, 1000, 800)
TSUM2_MIN, TSUM2_MAX, TSUM2_INIT = (0.0, 1000, 800)
# ---- Helper for bounded parameters ----
class BoundedParameter(torch.nn.Module):
    def __init__(self, low, high, init_value):
        super().__init__()
        self.low = low
        self.high = high
        # Normalize the initial value to [0, 1]
        init_norm = (init_value - low) / (high - low)
        # Parameter lives in unconstrained (logit) space
        self.raw = torch.nn.Parameter(
            torch.logit(torch.tensor(init_norm, dtype=torch.float32), eps=1e-6)
        )

    def forward(self):
        # Map back to [low, high] via a sigmoid
        return self.low + (self.high - self.low) * torch.sigmoid(self.raw)
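To make the bounds and the gradient warning above concrete, here is a plain-Python sketch of the same logit/sigmoid mapping (a standalone toy version, not part of diffwofost; no torch required):

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def bounded(low, high, raw):
    # same mapping as BoundedParameter.forward: low + (high - low) * sigmoid(raw)
    return low + (high - low) / (1.0 + math.exp(-raw))

# initializing TSUMEM at 90 inside [0, 200] round-trips through logit space
raw = logit((90.0 - 0.0) / (200.0 - 0.0))
print(round(bounded(0.0, 200.0, raw), 6))  # 90.0

# wherever the optimizer moves raw, the value stays strictly inside the bounds
print(round(bounded(0.0, 200.0, raw + 5.0), 1))  # ~198.4, still below 200

def grad_raw(low, high, raw):
    # d/d(raw) of the mapping: (high - low) * sigmoid(raw) * (1 - sigmoid(raw))
    s = 1.0 / (1.0 + math.exp(-raw))
    return (high - low) * s * (1.0 - s)

print(grad_raw(0.0, 200.0, 0.0))   # 50.0
print(grad_raw(95.0, 105.0, 0.0))  # 2.5 -> a very narrow range shrinks gradients
```

The gradient with respect to raw is proportional to (high - low), which is why an overly tight range makes optimization slow.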
Another helper class is OptDiffPhenology, a subclass of torch.nn.Module.
We use this class to wrap EngineTestHelper and make it easier to run the phenology model.
# ---- Wrap the model with torch.nn.Module ----
class OptDiffPhenology(torch.nn.Module):
    def __init__(self, crop_model_params_provider, weather_data_provider, agro_management_inputs, phenology_config):
        super().__init__()
        self.crop_model_params_provider = crop_model_params_provider
        self.weather_data_provider = weather_data_provider
        self.agro_management_inputs = agro_management_inputs
        self.config = phenology_config
        # bounded parameters
        self.TSUMEM = BoundedParameter(TSUMEM_MIN, TSUMEM_MAX, TSUMEM_INIT)
        self.TBASEM = BoundedParameter(TBASEM_MIN, TBASEM_MAX, TBASEM_INIT)
        self.TSUM1 = BoundedParameter(TSUM1_MIN, TSUM1_MAX, TSUM1_INIT)
        self.TSUM2 = BoundedParameter(TSUM2_MIN, TSUM2_MAX, TSUM2_INIT)

    def forward(self):
        # currently, copying is needed due to an internal issue in the engine
        crop_model_params_provider_ = copy.deepcopy(self.crop_model_params_provider)
        TSUMEM_val = self.TSUMEM()
        TBASEM_val = self.TBASEM()
        TSUM1_val = self.TSUM1()
        TSUM2_val = self.TSUM2()
        # pass the new parameter values to the model
        crop_model_params_provider_.set_override("TSUMEM", TSUMEM_val, check=False)
        crop_model_params_provider_.set_override("TBASEM", TBASEM_val, check=False)
        crop_model_params_provider_.set_override("TSUM1", TSUM1_val, check=False)
        crop_model_params_provider_.set_override("TSUM2", TSUM2_val, check=False)
        engine = EngineTestHelper(
            crop_model_params_provider_,
            self.weather_data_provider,
            self.agro_management_inputs,
            self.config,
        )
        engine.run_till_terminate()
        results = engine.get_output()
        return torch.stack([item["DVS"] for item in results])  # shape: [time_steps]
# ---- Create model ----
opt_model = OptDiffPhenology(
    crop_model_params_provider,
    weather_data_provider,
    agro_management_inputs,
    phenology_config,
)
# ---- Early stopping ----
best_loss = float("inf")
patience = 10 # Number of steps to wait for improvement
patience_counter = 0
min_delta = 1e-4
# ---- Optimizer ----
optimizer = torch.optim.Adam(opt_model.parameters(), lr=0.1)
# ---- We use relative MAE as loss because there are two outputs with different units ----
denom = torch.mean(torch.abs(expected_dvs))
# Training loop (example)
for step in range(101):
    optimizer.zero_grad()
    results = opt_model()
    # phenology parameters can change the simulation duration
    min_len = min(len(results), len(expected_dvs))
    if len(results) != len(expected_dvs):
        print(f"Step {step}: duration mismatch ({len(results)} vs {len(expected_dvs)}).")
    mae = torch.mean(torch.abs(results[:min_len] - expected_dvs[:min_len]))
    loss = mae / denom  # example: relative mean absolute error
    loss.backward()
    optimizer.step()
    print(
        f"Step {step}, Loss {loss.item():.4f}, "
        f"TSUMEM {opt_model.TSUMEM().item():.4f}, "
        f"TBASEM {opt_model.TBASEM().item():.4f}, "
        f"TSUM1 {opt_model.TSUM1().item():.4f}, "
        f"TSUM2 {opt_model.TSUM2().item():.4f},"
    )
    # Early stopping logic
    if loss.item() < best_loss - min_delta:
        best_loss = loss.item()
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at step {step}")
            print(f"duration (model {len(results)} vs test {len(expected_dvs)}).")
            break
Step 0: duration mismatch (260 vs 279).
Step 0, Loss 0.1490, TSUMEM 85.0787, TBASEM 1.8448, TSUM1 815.5215, TSUM2 815.5215,
Step 1: duration mismatch (262 vs 279).
Step 1, Loss 0.1348, TSUMEM 80.2344, TBASEM 1.6999, TSUM1 830.0543, TSUM2 830.0643,
Step 2: duration mismatch (263 vs 279).
Step 2, Loss 0.1197, TSUMEM 77.2860, TBASEM 1.6076, TSUM1 843.6052, TSUM2 843.6012,
Step 3: duration mismatch (264 vs 279).
Step 3, Loss 0.1147, TSUMEM 76.5338, TBASEM 1.5720, TSUM1 856.1688, TSUM2 856.1740,
Step 4: duration mismatch (266 vs 279).
Step 4, Loss 0.1019, TSUMEM 77.1810, TBASEM 1.5731, TSUM1 867.7785, TSUM2 867.8158,
Step 5: duration mismatch (267 vs 279).
Step 5, Loss 0.0881, TSUMEM 78.6763, TBASEM 1.5976, TSUM1 878.4762, TSUM2 878.5369,
Step 6: duration mismatch (268 vs 279).
Step 6, Loss 0.0830, TSUMEM 80.7683, TBASEM 1.6402, TSUM1 888.2892, TSUM2 888.3950,
Step 7: duration mismatch (269 vs 279).
Step 7, Loss 0.0698, TSUMEM 82.9896, TBASEM 1.6870, TSUM1 897.2725, TSUM2 897.4227,
Step 8: duration mismatch (270 vs 279).
Step 8, Loss 0.0568, TSUMEM 84.5758, TBASEM 1.7161, TSUM1 905.4835, TSUM2 905.6589,
Step 9: duration mismatch (271 vs 279).
Step 9, Loss 0.0521, TSUMEM 84.9177, TBASEM 1.7125, TSUM1 912.9635, TSUM2 913.1725,
Step 10: duration mismatch (271 vs 279).
Step 10, Loss 0.0480, TSUMEM 84.3238, TBASEM 1.6843, TSUM1 919.7631, TSUM2 920.0091,
Step 11: duration mismatch (273 vs 279).
Step 11, Loss 0.0381, TSUMEM 83.4182, TBASEM 1.6478, TSUM1 925.9421, TSUM2 926.2325,
Step 12: duration mismatch (273 vs 279).
Step 12, Loss 0.0355, TSUMEM 82.3086, TBASEM 1.6063, TSUM1 931.5499, TSUM2 931.8865,
Step 13: duration mismatch (273 vs 279).
Step 13, Loss 0.0324, TSUMEM 81.3026, TBASEM 1.5680, TSUM1 936.6345, TSUM2 937.0161,
Step 14: duration mismatch (275 vs 279).
Step 14, Loss 0.0245, TSUMEM 80.8495, TBASEM 1.5439, TSUM1 941.2473, TSUM2 941.6774,
Step 15: duration mismatch (275 vs 279).
Step 15, Loss 0.0220, TSUMEM 81.1065, TBASEM 1.5381, TSUM1 945.4302, TSUM2 945.9092,
Step 16: duration mismatch (275 vs 279).
Step 16, Loss 0.0197, TSUMEM 81.9637, TBASEM 1.5478, TSUM1 949.2226, TSUM2 949.7485,
Step 17: duration mismatch (276 vs 279).
Step 17, Loss 0.0103, TSUMEM 83.1409, TBASEM 1.5657, TSUM1 952.6663, TSUM2 953.2308,
Step 18: duration mismatch (277 vs 279).
Step 18, Loss 0.0093, TSUMEM 84.1272, TBASEM 1.5787, TSUM1 955.4659, TSUM2 956.3961,
Step 19: duration mismatch (277 vs 279).
Step 19, Loss 0.0093, TSUMEM 84.7385, TBASEM 1.5820, TSUM1 957.7150, TSUM2 959.2729,
Step 20: duration mismatch (277 vs 279).
Step 20, Loss 0.0093, TSUMEM 85.0120, TBASEM 1.5765, TSUM1 959.5129, TSUM2 961.8885,
Step 21: duration mismatch (277 vs 279).
Step 21, Loss 0.0092, TSUMEM 84.9791, TBASEM 1.5633, TSUM1 960.9411, TSUM2 964.2680,
Step 22: duration mismatch (277 vs 279).
Step 22, Loss 0.0091, TSUMEM 84.6666, TBASEM 1.5432, TSUM1 962.0599, TSUM2 966.4341,
Step 23: duration mismatch (278 vs 279).
Step 23, Loss 0.0090, TSUMEM 84.0982, TBASEM 1.5171, TSUM1 962.9180, TSUM2 968.4114,
Step 24, Loss 0.0078, TSUMEM 83.4926, TBASEM 1.4905, TSUM1 963.5505, TSUM2 970.0585,
Step 25, Loss 0.0082, TSUMEM 83.2271, TBASEM 1.4719, TSUM1 963.9872, TSUM2 971.4006,
Step 26, Loss 0.0086, TSUMEM 83.4078, TBASEM 1.4639, TSUM1 964.2517, TSUM2 972.4788,
Step 27, Loss 0.0090, TSUMEM 83.9896, TBASEM 1.4651, TSUM1 964.3623, TSUM2 973.3393,
Step 28, Loss 0.0092, TSUMEM 84.8013, TBASEM 1.4715, TSUM1 964.3331, TSUM2 974.0173,
Step 29, Loss 0.0093, TSUMEM 85.4506, TBASEM 1.4742, TSUM1 964.1751, TSUM2 974.5405,
Step 30, Loss 0.0094, TSUMEM 85.9211, TBASEM 1.4726, TSUM1 963.8970, TSUM2 974.9305,
Step 31, Loss 0.0094, TSUMEM 86.0486, TBASEM 1.4633, TSUM1 963.5046, TSUM2 975.2042,
Step 32, Loss 0.0093, TSUMEM 85.8631, TBASEM 1.4471, TSUM1 963.0023, TSUM2 975.3755,
Step 33, Loss 0.0092, TSUMEM 85.6006, TBASEM 1.4293, TSUM1 962.3921, TSUM2 975.4550,
Step 34, Loss 0.0090, TSUMEM 85.2661, TBASEM 1.4101, TSUM1 961.6753, TSUM2 975.4510,
Early stopping at step 34
duration (model 279 vs test 279).
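The "Early stopping at step 34" line comes from the patience logic in the training loop. Isolated in plain Python (a toy sketch with made-up loss values, not the real run), the rule behaves like this:

```python
def early_stop_step(losses, patience=10, min_delta=1e-4):
    """Return the step at which training would stop, or None if it never does."""
    best_loss = float("inf")
    patience_counter = 0
    for step, loss in enumerate(losses):
        if loss < best_loss - min_delta:
            # meaningful improvement: record it and reset the counter
            best_loss = loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                return step
    return None

# a loss that improves for 5 steps and then plateaus stops 10 steps later
losses = [1.0, 0.8, 0.6, 0.5, 0.4] + [0.4] * 20
print(early_stop_step(losses))  # 14
```

Improvements smaller than min_delta do not reset the counter, so the loop also stops when the loss merely crawls, as happened around step 30 above.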
# ---- validate the results using test data ----
print(
    f"Actual TSUMEM {crop_model_params_provider['TSUMEM'].item():.4f}",
    f"TBASEM {crop_model_params_provider['TBASEM'].item():.4f}",
    f"Actual TSUM1 {crop_model_params_provider['TSUM1'].item():.4f}",
    f"TSUM2 {crop_model_params_provider['TSUM2'].item():.4f}",
)
Actual TSUMEM 110.0000 TBASEM 0.0000 Actual TSUM1 950.0000 TSUM2 991.0000
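As a rough sanity check (values copied by hand from the two outputs above; this cell is only an illustration), the fitted values from step 34 can be compared against the true test-data parameters:

```python
# true values from the test data (printed above)
true_vals = {"TSUMEM": 110.0, "TBASEM": 0.0, "TSUM1": 950.0, "TSUM2": 991.0}
# last optimized values (step 34 of the training output)
fitted = {"TSUMEM": 85.2661, "TBASEM": 1.4101, "TSUM1": 961.6753, "TSUM2": 975.4510}

for name, true in true_vals.items():
    err = abs(fitted[name] - true)
    if true != 0.0:
        print(f"{name}: {100.0 * err / true:.1f}% off")
    else:
        # relative error is undefined when the true value is 0 (TBASEM)
        print(f"{name}: absolute error {err:.2f}")
```

TSUM1 and TSUM2 are recovered to within about 2%, while TSUMEM and TBASEM are further off; one plausible reading is that emergence timing depends on both of them jointly, so different (TSUMEM, TBASEM) pairs can produce nearly the same DVS trajectory.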