
API reference

Block

Bases: Module

A basic block for the ResNet architecture. This block consists of two 3x3 convolutional layers, each followed by batch normalization, with ReLU activation. The block also supports downsampling through an optional identity downsample layer. The expansion factor is set to 1, meaning the block does not expand the number of output channels.

taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py

Source code in windscangeo\Models.py
class Block(nn.Module):
    """
    A basic block for ResNet architecture.
    This block consists of two convolutional layers with batch normalization
    and ReLU activation. The first layer applies a 3x3 convolution, and the
    second layer applies another 3x3 convolution. The block also supports
    downsampling through an optional identity downsample layer. The expansion
    factor is set to 1, meaning the output channels are the same as the input
    channels.

    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    """

    expansion = 1
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()


        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)  # only conv1 carries the stride
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        x = self.relu(self.batch_norm1(self.conv1(x)))  # conv -> BN -> ReLU
        x = self.batch_norm2(self.conv2(x))             # conv -> BN

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        x += identity
        x = self.relu(x)
        return x
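
A minimal usage sketch (the import path is an assumption based on the source location shown above, windscangeo\Models.py):

import torch
from windscangeo.Models import Block

# A block that keeps 64 channels and does not downsample,
# so no identity-downsample layer is needed.
block = Block(in_channels=64, out_channels=64, i_downsample=None, stride=1)

x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
y = block(x)                     # residual output, same shape as the input
print(y.shape)                   # torch.Size([2, 64, 32, 32])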

Bottleneck

Bases: Module

A bottleneck block for the ResNet architecture. This block consists of three convolutional layers with batch normalization and ReLU activation. The first layer reduces the number of channels, the second layer applies a 3x3 convolution, and the third layer expands the number of channels by the block's expansion factor of 4. The block also supports downsampling through an optional identity downsample layer. Taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py

Source code in windscangeo\Models.py
class Bottleneck(nn.Module):
    """
    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    A bottleneck block for ResNet architecture.
    This block consists of three convolutional layers with batch normalization
    and ReLU activation. The first layer reduces the number of channels,
    the second layer applies a 3x3 convolution, and the third layer expands
    the number of channels back to the original size. The block also supports
    downsampling through an optional identity downsample layer.
    """
    expansion = 4
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()
        x = self.relu(self.batch_norm1(self.conv1(x)))

        x = self.relu(self.batch_norm2(self.conv2(x)))

        x = self.conv3(x)
        x = self.batch_norm3(x)

        #downsample if needed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        #add identity
        x+=identity
        x=self.relu(x)

        return x
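
A minimal usage sketch (import path assumed from the source location above). Because expansion is 4, the identity path needs a 1x1 projection whenever the input channels differ from out_channels * 4:

import torch
import torch.nn as nn
from windscangeo.Models import Bottleneck

# Project the identity from 64 to 64 * 4 = 256 channels so it can be added
# to the output of conv3.
downsample = nn.Sequential(
    nn.Conv2d(64, 64 * Bottleneck.expansion, kernel_size=1, stride=1),
    nn.BatchNorm2d(64 * Bottleneck.expansion),
)
block = Bottleneck(in_channels=64, out_channels=64, i_downsample=downsample, stride=1)

x = torch.randn(2, 64, 32, 32)
y = block(x)
print(y.shape)   # torch.Size([2, 256, 32, 32])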

ConventionalCNN

Bases: Module

A simple CNN for image regression tasks. This model consists of a series of convolutional layers followed by fully connected layers. It is designed to process images and output a single regression value (e.g., wind speed).

Source code in windscangeo\Models.py
class ConventionalCNN(nn.Module):
    """
    A simple CNN for image regression tasks.
    This model consists of a series of convolutional layers followed by
    fully connected layers. It is designed to process images and output a
    single regression value (e.g., wind speed).
    """
    def __init__(
        self,
        image_height: int,
        image_width: int,
        features_cnn: list[int],
        kernel_size: int,
        in_channels: int,
        activation_cnn: nn.Module = nn.ReLU(),
        activation_final: nn.Module = nn.Identity(),
        stride: int = 1,
        dropout_rate: float = 0.2,
    ):
        super().__init__()
        self.activation_cnn = activation_cnn
        self.activation_final = activation_final
        self.dropout_rate = dropout_rate

        # ------- Convolutional backbone -------
        self.convs = nn.ModuleList()
        for feature in features_cnn:
            self.convs.extend(
                [
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=feature,
                        kernel_size=kernel_size,
                        padding=1,
                        stride=stride,
                    ),
                    self.activation_cnn,
                    nn.MaxPool2d(kernel_size=2),
                    nn.Dropout(self.dropout_rate),
                ]
            )
            in_channels = feature

        # ------- Classifier / regressor head -------
        self.flattened_size = self._get_flattened_size(image_height, image_width)
        self.fc_cnn = nn.Linear(self.flattened_size, 64)
        self.dropout_cnn = nn.Dropout(self.dropout_rate)
        self.head = nn.Sequential(
            nn.Linear(64, 16),
            self.activation_cnn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(16, 1),
            self.activation_final,
        )

    def _get_flattened_size(self, h, w):
        x = torch.zeros(1, self.convs[0].in_channels, h, w)
        for layer in self.convs:
            x = layer(x)
        return x.numel()

    def forward(self, image):
        x = image
        for layer in self.convs:
            x = layer(x)
        x = x.view(x.size(0), -1)              # flatten
        x = self.activation_cnn(self.fc_cnn(x))
        x = self.dropout_cnn(x)
        out = self.head(x)
        return out
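
A short usage sketch (import path and parameter values are illustrative assumptions):

import torch
from windscangeo.Models import ConventionalCNN

# Two convolutional stages (16 and 32 filters); each stage halves the spatial
# size through MaxPool2d, and the head maps the flattened features to one value.
model = ConventionalCNN(
    image_height=128,
    image_width=128,
    features_cnn=[16, 32],
    kernel_size=3,
    in_channels=1,
)

images = torch.randn(4, 1, 128, 128)   # (batch, channels, height, width)
wind_speed = model(images)
print(wind_speed.shape)                # torch.Size([4, 1])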

H5pyDataset

Bases: Dataset

A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory. Zarr integration is planned for better performance.

Parameters:

    h5_file_path (str): Path to the HDF5 file. Required.
    transform (callable, optional): A function/transform to apply to the images. Default: None.
Source code in windscangeo\func_ml.py
class H5pyDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory.
    Need to work on Zarr integration for better performance

    Args:
        h5_file_path (str): Path to the HDF5 file.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, h5_file_path, transform=None):
        self.h5_file_path = h5_file_path
        self.transform = transform
        self.file = None  # Will be initialized per worker
        with h5py.File(self.h5_file_path, 'r') as f:
            self.length = len(f['targets'])

    def _ensure_file(self):
        if self.file is None:
            self.file = h5py.File(self.h5_file_path, 'r')

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self._ensure_file()

        image = self.file['images'][idx]
        target = self.file['targets'][idx]

        image = torch.tensor(image, dtype=torch.float32)
        target = torch.tensor(target, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, target

    def __del__(self):
        if self.file:
            self.file.close()
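
A usage sketch with a DataLoader (the file name is a placeholder; the HDF5 file is expected to contain the 'images' and 'targets' datasets read by the class):

import torch
from windscangeo.func_ml import H5pyDataset, Normalize

dataset = H5pyDataset("training_data.h5", transform=Normalize(mean=[0.5], std=[0.2]))
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Each worker opens its own handle to the HDF5 file via _ensure_file().
for images, targets in loader:
    print(images.shape, targets.shape)
    break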

Img2Seq

Bases: Module

This layer takes a batch of images as input and returns a batch of sequences.

Shape

input: (b, h, w, c) output: (b, s, d)

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py
class Img2Seq(nn.Module):
    """
    This layers takes a batch of images as input and
    returns a batch of sequences

    Shape:
        input: (b, h, w, c)
        output: (b, s, d)

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    def __init__(self, img_size, patch_size, n_channels, d_model):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        n_tokens = nh * nw

        token_dim = patch_size[0] * patch_size[1] * n_channels
        self.linear = nn.Linear(token_dim, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.randn(n_tokens, d_model))

    def __call__(self, batch):
        batch = patchify(batch, self.patch_size)

        b, c, nh, nw, ph, pw = batch.shape

        # Flattening the patches
        batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
        batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

        batch = self.linear(batch)
        cls = self.cls_token.expand([b, -1, -1])
        emb = batch + self.pos_emb

        return torch.cat([cls, emb], axis=1)

Normalize

Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch

Parameters:

    mean (list or ndarray): Mean values for normalization. Required.
    std (list or ndarray): Standard deviation values for normalization. Required.
Source code in windscangeo\func_ml.py
class Normalize:

    """
    Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch 

    Args:
        mean (list or np.ndarray): Mean values for normalization.
        std (list or np.ndarray): Standard deviation values for normalization.
    """
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, x):
        return (x - self.mean) / self.std
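
A usage sketch (the statistics below are placeholders; in practice they are computed from the training images):

import torch
from windscangeo.func_ml import Normalize

normalize = Normalize(mean=[0.45], std=[0.18])   # one value per channel

image = torch.rand(1, 128, 128)                  # (channels, height, width)
normalized = normalize(image)                    # (image - mean) / std, broadcast per channel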

ResNet

Bases: Module

A ResNet model for image classification or regression tasks. This model consists of an initial convolutional layer, followed by a series of residual blocks, and a fully connected layer for classification or regression. The number of residual blocks in each layer is specified by the layer_list parameter. Taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py

Source code in windscangeo\Models.py
class ResNet(nn.Module):

    """
    A ResNet model for image classification or regression tasks.
    This model consists of an initial convolutional layer, followed by a series
    of residual blocks, and a fully connected layer for classification or regression.
    The number of residual blocks in each layer is specified by the `layer_list` parameter.
    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    """
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2, padding=1)

        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*ResBlock.expansion, num_classes)

    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        ii_downsample = None
        layers = []

        if stride != 1 or self.in_channels != planes*ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes*ResBlock.expansion)
            )

        layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
        self.in_channels = planes*ResBlock.expansion

        for i in range(blocks-1):
            layers.append(ResBlock(self.in_channels, planes))

        return nn.Sequential(*layers)
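
A usage sketch (import path assumed from the source location above). A ResNet-50-style configuration uses Bottleneck blocks with layer_list=[3, 4, 6, 3]; num_classes=1 turns the final fully connected layer into a regression head:

import torch
from windscangeo.Models import ResNet, Bottleneck

model = ResNet(Bottleneck, layer_list=[3, 4, 6, 3], num_classes=1, num_channels=1)

images = torch.randn(2, 1, 128, 128)   # single-channel GOES-style images
out = model(images)
print(out.shape)                       # torch.Size([2, 1])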

ViT

Bases: Module

Vision Transformer (ViT) model for image classification or regression tasks. This model consists of an image-to-sequence layer, a transformer encoder, and a multi-layer perceptron (MLP) head for classification or regression.

Taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py
class ViT(nn.Module):
    """
    Vision Transformer (ViT) model for image classification or regression tasks.
    This model consists of an image-to-sequence layer, a transformer encoder,
    and a multi-layer perceptron (MLP) head for classification or regression.

    Taken from # https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    def __init__(
        self,
        img_size,
        patch_size,
        n_channels,
        d_model,
        nhead,
        dim_feedforward,
        blocks,
        mlp_head_units,
        n_classes,
    ):
        super().__init__()
        """
        Args:
            img_size: Size of the image
            patch_size: Size of the patch
            n_channels: Number of image channels
            d_model: The number of features in the transformer encoder
            nhead: The number of heads in the multiheadattention models
            dim_feedforward: The dimension of the feedforward network model in the encoder
            blocks: The number of sub-encoder-layers in the encoder
            mlp_head_units: The hidden units of mlp_head
            n_classes: The number of output classes
        """
        self.img2seq = Img2Seq(img_size, patch_size, n_channels, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, activation="gelu", batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, blocks
        )
        self.mlp = get_mlp(d_model, mlp_head_units, n_classes)

        self.output = nn.Identity() # For regression

    def forward(self, batch):

        batch = self.img2seq(batch)
        batch = self.transformer_encoder(batch)
        batch = batch[:, 0, :]
        batch = self.mlp(batch)
        output = self.output(batch)
        return output
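
A usage sketch (illustrative hyperparameters; the import path, the channel-first input layout, and mlp_head_units being a list of hidden-layer sizes are assumptions, since the patchify and get_mlp helpers are defined elsewhere in the module):

import torch
from windscangeo.Models import ViT

model = ViT(
    img_size=(128, 128),
    patch_size=(16, 16),
    n_channels=1,
    d_model=64,
    nhead=4,
    dim_feedforward=128,
    blocks=2,
    mlp_head_units=[32],
    n_classes=1,
)

images = torch.randn(4, 1, 128, 128)
out = model(images)   # one regression value per image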

conventional_dataset

Bases: Dataset

A PyTorch Dataset for loading data using regular numpy arrays.

Parameters:

    images (list or ndarray): List or array of images. Required.
    targets (list or ndarray): List or array of targets corresponding to the images. Required.
    transform (callable, optional): A function/transform to apply to the images. Default: None.
Source code in windscangeo\func_ml.py
class conventional_dataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images.
        targets (list or np.ndarray): List or array of targets corresponding to the images.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images, targets, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        # 2) Target for sample "idx"
        target = torch.tensor(self.targets[idx], dtype=torch.float32)

        return image, target
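
A usage sketch wrapping in-memory arrays (random placeholder data; in practice the images and wind speeds come from extract_matching_orbits):

import numpy as np
import torch
from windscangeo.func_ml import conventional_dataset

images = np.random.rand(100, 1, 128, 128).astype(np.float32)
wind_speeds = np.random.rand(100).astype(np.float32)

dataset = conventional_dataset(images, wind_speeds)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

batch_images, batch_targets = next(iter(loader))
print(batch_images.shape, batch_targets.shape)   # (16, 1, 128, 128) and (16,)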

conventional_dataset_inference

Bases: Dataset

A PyTorch Dataset for loading data for inference (no labels) using regular numpy arrays.

Parameters:

    images (list or ndarray): List or array of images to be used for inference. Required.
    transform (callable, optional): A function/transform to apply to the images. Default: None.
Source code in windscangeo\func_ml.py
class conventional_dataset_inference(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data for inference (no lable) using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images to be used for inference.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images,transform=None):
        self.images = images
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        return image

buoy_data_extract(folder_path, polar_data, date)

Extracts buoy data from a specified folder and returns arrays of latitude, longitude, time, wind speed, and buoy names.

Parameters:

    folder_path (str): Path to the folder containing buoy data files. Required.
    polar_data (Dataset): Polar data containing latitude and longitude information. Used to snap buoy data to the nearest polar grid points. Required.
    date (str): Date for which to extract buoy data, in 'YYYY-MM-DD' format. Required.

Returns:

    buoy_lat (ndarray): Array of buoy latitudes snapped to the nearest polar grid points.
    buoy_lon (ndarray): Array of buoy longitudes snapped to the nearest polar grid points.
    buoy_time (ndarray): Array of buoy observation times.
    buoy_wind_speed (ndarray): Array of buoy wind speeds.
    buoy_name (ndarray): Array of buoy names.

Source code in windscangeo\func_inference.py
def buoy_data_extract(folder_path, polar_data, date):
    """ Extracts buoy data from a specified folder and returns arrays of latitude, longitude, time, wind speed, and buoy names.

    Args:
        folder_path (str): Path to the folder containing buoy data files.
        polar_data (xarray.Dataset): Polar data containing latitude and longitude information. Used to snap buoy data to the nearest polar grid points.
        date (str): Date for which to extract buoy data, in 'YYYY-MM-DD' format.

    Returns:
        buoy_lat (np.ndarray): Array of buoy latitudes snapped to the nearest polar grid points.
        buoy_lon (np.ndarray): Array of buoy longitudes snapped to the nearest polar grid points
        buoy_time (np.ndarray): Array of buoy observation times.
        buoy_wind_speed (np.ndarray): Array of buoy wind speeds.
        buoy_name (np.ndarray): Array of buoy names.
    """
    buoy_lat = []
    buoy_lon = []
    buoy_wind_speed = []
    buoy_time = []
    buoy_name = []

    for file in os.listdir(folder_path):
        if ".cdf" in file:
            file_path = os.path.join(folder_path, file)
            opened = xr.open_dataset(file_path)
            lat, lon, time, wind_speed, name = form_arrays_buoy(opened, date)
            if np.sum(wind_speed) > 0:
                buoy_lat.extend(lat)
                buoy_lon.extend(lon)
                buoy_wind_speed.extend(wind_speed)
                buoy_time.append(time)
                buoy_name.append(name)

    buoy_lat = np.array(buoy_lat)
    buoy_lat = snap_to_nearest(buoy_lat, polar_data.latitude.values, cutoff=0.8)
    buoy_lon = np.array(buoy_lon)
    buoy_wind_speed = np.array(buoy_wind_speed)
    buoy_time = np.array(buoy_time)

    buoy_lon = np.where(buoy_lon > 180, buoy_lon - 360, buoy_lon)
    buoy_lon = snap_to_nearest(buoy_lon, polar_data.longitude.values, cutoff=0.8)

    return buoy_lat, buoy_lon, buoy_time, buoy_wind_speed, buoy_name
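
A usage sketch (both paths are placeholders; the folder should contain buoy .cdf files, and the scatterometer file provides the latitude/longitude grid that the buoys are snapped onto):

import xarray as xr
from windscangeo.func_inference import buoy_data_extract

polar_data = xr.open_dataset("./scatterometer_data/ascat_file.nc")
buoy_lat, buoy_lon, buoy_time, buoy_wind_speed, buoy_name = buoy_data_extract(
    "./buoy_data/", polar_data, date="2022-07-01"
)
print(len(buoy_name), "buoys with valid wind speeds")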

calculate_degrees(file_id)

This function calculates the latitude and longitude of the GOES ABI fixed grid projection. It comes from NOAA/NESDIS/STAR (2025), Latitude and longitude remapping of GOES-R ABI imagery using Python, Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

Parameters:

    file_id (Dataset): The xarray dataset containing the GOES ABI fixed grid projection variables. Required.

Returns:

    abi_lat (ndarray): The latitude of the GOES ABI fixed grid projection.
    abi_lon (ndarray): The longitude of the GOES ABI fixed grid projection.

Source code in windscangeo\func.py
def calculate_degrees(file_id):
    """This function calculates the latitude and longitude of the GOES ABI fixed grid projection. 
    This function comes from NOAA/NESDIS/STAR. (2025). Latitude and longitude remapping of GOES-R ABI imagery using Python . Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

    Args:
        file_id (xarray.Dataset): The xarray dataset containing the GOES ABI fixed grid projection variables.

    Returns:
        abi_lat (numpy.ndarray): The latitude of the GOES ABI fixed grid projection.
        abi_lon (numpy.ndarray): The longitude of the GOES ABI fixed grid projection.


    """

    # Read in GOES ABI fixed grid projection variables and constants
    x_coordinate_1d = file_id.variables["x"][:]  # E/W scanning angle in radians
    y_coordinate_1d = file_id.variables["y"][:]  # N/S elevation angle in radians
    projection_info = file_id.goes_imager_projection
    lon_origin = projection_info.longitude_of_projection_origin
    H = projection_info.perspective_point_height + projection_info.semi_major_axis
    r_eq = projection_info.semi_major_axis
    r_pol = projection_info.semi_minor_axis

    # Create 2D coordinate matrices from 1D coordinate vectors
    x_coordinate_2d, y_coordinate_2d = np.meshgrid(x_coordinate_1d, y_coordinate_1d)

    # Equations to calculate latitude and longitude
    lambda_0 = (lon_origin * np.pi) / 180.0
    a_var = np.power(np.sin(x_coordinate_2d), 2.0) + (
        np.power(np.cos(x_coordinate_2d), 2.0)
        * (
            np.power(np.cos(y_coordinate_2d), 2.0)
            + (
                ((r_eq * r_eq) / (r_pol * r_pol))
                * np.power(np.sin(y_coordinate_2d), 2.0)
            )
        )
    )
    b_var = -2.0 * H * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    c_var = (H**2.0) - (r_eq**2.0)
    r_s = (-1.0 * b_var - np.sqrt((b_var**2) - (4.0 * a_var * c_var))) / (2.0 * a_var)
    s_x = r_s * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    s_y = -r_s * np.sin(x_coordinate_2d)
    s_z = r_s * np.cos(x_coordinate_2d) * np.sin(y_coordinate_2d)

    # Ignore numpy errors for sqrt of negative number; occurs for GOES-16 ABI CONUS sector data
    np.seterr(all="ignore")

    abi_lat = (180.0 / np.pi) * (
        np.arctan(
            ((r_eq * r_eq) / (r_pol * r_pol))
            * ((s_z / np.sqrt(((H - s_x) * (H - s_x)) + (s_y * s_y))))
        )
    )
    abi_lon = (lambda_0 - np.arctan(s_y / (H - s_x))) * (180.0 / np.pi)

    print("INFO : Latitude and longitude calculated")
    return abi_lat, abi_lon
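
A usage sketch (the file name is a placeholder for any GOES-R ABI L2 file, e.g. one downloaded from the noaa-goes16 S3 bucket, containing the fixed grid projection variables):

import xarray as xr
from windscangeo.func import calculate_degrees

goes_ds = xr.open_dataset("OR_ABI-L2-CMIPF_sample.nc")
abi_lat, abi_lon = calculate_degrees(goes_ds)
print(abi_lat.shape, abi_lon.shape)   # 2D arrays on the ABI fixed grid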

create_folder(experiment_name)

Create a folder for saving results based on the experiment name.

Parameters:

    experiment_name (str): Name of the experiment to create a folder for. Required.

Returns:

    str: Path to the created folder.

Source code in windscangeo\func.py
def create_folder(experiment_name):
    """
    Create a folder for saving results based on the experiment name.

    Args:
        experiment_name (str): Name of the experiment to create a folder for.
        If the folder already exists, it will not be created again.

    Returns:
        str: Path to the created folder.
    """

    path_folder = f"./results_folder/model_day_{experiment_name}"

    if not os.path.exists(path_folder):
        os.makedirs(path_folder)
        print(f"Folder created at {path_folder}")

    return path_folder

early_stopping(valid_losses, patience_epochs, patience_loss)

Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

Parameters:

    valid_losses (list): List of validation losses recorded during training. Required.
    patience_epochs (int): Number of epochs to wait before stopping if no improvement. Required.
    patience_loss (float): Minimum change in validation loss to consider as an improvement. Required.

Returns:

    bool: True if training should stop, False otherwise.

Source code in windscangeo\impl.py
def early_stopping(valid_losses, patience_epochs, patience_loss):  # From @Jing
    """
    Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

    Args:
        valid_losses (list): List of validation losses recorded during training.
        patience_epochs (int): Number of epochs to wait before stopping if no improvement.
        patience_loss (float): Minimum change in validation loss to consider as an improvement.

    Returns:
        bool: True if training should stop, False otherwise.
    """
    if len(valid_losses) < patience_epochs:
        return False
    recent_losses = valid_losses[-patience_epochs:]

    if all(x >= recent_losses[0] for x in recent_losses):
        return True

    if max(recent_losses) - min(recent_losses) < patience_loss:
        return True
    return False
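
A usage sketch (illustrative loss values; import path assumed from the source location above). The last four losses change by less than patience_loss, so the function signals a stop:

from windscangeo.impl import early_stopping

valid_losses = [1.20, 0.95, 0.80, 0.795, 0.793, 0.792]

if early_stopping(valid_losses, patience_epochs=4, patience_loss=0.01):
    print("Stopping training: validation loss has plateaued.")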

error_plot(best_val_outputs, best_val_labels, path_folder=None)

Plot a scatter plot of model outputs vs true labels for the validation dataset.

Parameters:

    best_val_outputs (list or ndarray): Model outputs for the validation dataset. Required.
    best_val_labels (list or ndarray): True labels for the validation dataset. Required.
    path_folder (str, optional): Path to save the plot. If None, the plot will not be saved. Default: None.
Source code in windscangeo\func_ml.py
def error_plot(best_val_outputs, best_val_labels, path_folder=None):
    """
    Plot a scatter plot of model outputs vs true labels for the validation dataset.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        path_folder (str, optional): Path to save the plot. If None, the plot will not be saved.
    """

    max_all = max(max(best_val_outputs), max(best_val_labels))

    plt.figure(figsize=(5, 5))
    plt.plot(best_val_labels, best_val_outputs, "o")
    plt.gca().set_aspect("equal", adjustable="box")
    plt.xlim(0, max_all)
    plt.ylim(0, max_all)
    plt.xlabel("True Labels")
    plt.ylabel("Model Output")
    plt.title("Model Output vs True Labels in test dataset")
    plt.xticks(np.arange(0, 30, 5))
    plt.yticks(np.arange(0, 30, 5))
    plt.plot(
        [min(best_val_labels), max(best_val_labels)],
        [min(best_val_labels), max(best_val_labels)],
        "r--",
    )  # y = x reference line
    if path_folder:
        plt.savefig(os.path.join(path_folder, "scatter_plot.png"))
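
A usage sketch (illustrative arrays; with path_folder=None the scatter plot is created but not saved to disk):

import numpy as np
from windscangeo.func_ml import error_plot

true_speeds = np.array([3.2, 7.5, 12.1, 5.4])        # m/s
predicted_speeds = np.array([3.5, 7.1, 11.6, 6.0])   # m/s

error_plot(predicted_speeds, true_speeds, path_folder=None)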

extract_goes(observation_times, observation_lats, observation_lons, scatterometer_data_path, goes_aws_url_folder, goes_channel='C01', goes_image_size=128, verbose=True)

This function extracts GOES images for the given observation times, latitudes, and longitudes. It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

Parameters:

    observation_times (ndarray): The times of observation of the scatterometer data. Required.
    observation_lats (ndarray): The latitudes of the scatterometer data. Required.
    observation_lons (ndarray): The longitudes of the scatterometer data. Required.
    scatterometer_data_path (str): The path to the scatterometer data directory. Required.
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Required.
    goes_channel (str): The channel of interest. Default: 'C01'.
    goes_image_size (int): The size of the output images. Default: 128.
    verbose (bool): If True, prints progress information. Default: True.

Returns:

    images (ndarray): A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

Source code in windscangeo\func.py
def extract_goes(
    observation_times,
    observation_lats,
    observation_lons,
    scatterometer_data_path,
    goes_aws_url_folder,
    goes_channel="C01",
    goes_image_size=128,
    verbose=True,
):
    """
    This function extracts GOES images for the given observation times, latitudes, and longitudes.
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data. 
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        scatterometer_data_path (str): The path to the scatterometer data directory.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored.
        goes_channel (str): The channel of interest. Default is "C01".
        goes_image_size (int): The size of the output images. Default is 128.
        verbose (bool): If True, prints progress information.

    Returns:
        images (numpy.ndarray): A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

    """

    for file in os.listdir(scatterometer_data_path):
        if file.endswith(".nc"):
            polar = xr.open_dataset(
                os.path.join(scatterometer_data_path, file),
                engine="h5netcdf",
                drop_variables=["DQF"],
            )
            break

        else:
            print('WARNING : No .nc file found in the scatterometer data path, please check the path')

    template_scatter = polar.isel(time=0)
    lat_grd, lon_grd = (
        template_scatter["latitude"].values,
        template_scatter["longitude"].values,
    )

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**1024)

    values, counts = np.unique(observation_times, return_counts=True)

    all_urls = []  # getting unique URLS
    for value in values:
        urls = get_goes_url(value, goes_aws_url_folder,goes_channel)
        all_urls.append(urls)

    values_url, indices_url, counts_url = np.unique(
        all_urls, return_index=True, return_counts=True, axis=0
    )
    # Sort indices to "unsort" the URLs
    sorted_indices = sorted(range(len(indices_url)), key=lambda k: indices_url[k])
    values_url = [all_urls[indices_url[i]] for i in sorted_indices]

    # Reorder counts_url using the same sorted indices
    counts_url = [counts_url[i] for i in sorted_indices]

    compressed_urls = values_url
    compressed_counts = []
    start_idx = 0

    for size in counts_url:
        group_sum = counts[start_idx : start_idx + size].sum()
        compressed_counts.append(group_sum)
        start_idx += size

    width = goes_image_size
    height = goes_image_size

    images = np.zeros([len(observation_times), 1 , width, height], dtype=np.float32)

    total_idx = 0
    for unique_idx, unique_urls in tqdm(
        enumerate(compressed_urls),
        desc="INFO : Retrieving and processing GOES data",
        total=len(compressed_urls),
        disable=not verbose,
    ):


        for CH_idx, url_CH in enumerate(unique_urls):

            if url_CH == 0:
                images[total_idx, CH_idx] = np.zeros([width, height])
                continue

            with fs.open(url_CH, mode="rb") as f:

                ds = xr.open_dataset(
                    f, engine="h5netcdf", drop_variables=["DQF"]
                )  # this is the bottleneck

                parallel_index = index_parallel(
                    ds,
                    template_scatter,
                )
                for i in range(compressed_counts[unique_idx]):
                    images[total_idx + i, CH_idx] = get_image(
                        ds=ds,
                        parallel_index=parallel_index,
                        lat_grd=lat_grd,
                        lon_grd=lon_grd,
                        lat_search=observation_lats[total_idx + i],
                        lon_search=observation_lons[total_idx + i],
                        goes_image_size=goes_image_size,
                    )

        total_idx += compressed_counts[unique_idx]

    if verbose:
        print(
            f"INFO : Extracted {len(observation_times)} images from {len(compressed_urls)} GOES files."
        )
    return images

extract_goes_inference(date_time, parallel_index, channels='C01', goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF')

This function extracts GOES images for a given date_time and parallel_index (a whole GOES slice, used for inference; this differs from the training images, which are matched to scatterometer orbits). It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size (128x128).

Parameters:

    date_time (datetime64): The time of the GOES data. Required.
    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    channels (str or list): The channel(s) of interest. Default: 'C01'.
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Default: 'noaa-goes16/ABI-L2-CMIPF'.

Returns:

    images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).

Source code in windscangeo\func.py
def extract_goes_inference(date_time, parallel_index,channels="C01",goes_aws_url_folder= 'noaa-goes16/ABI-L2-CMIPF'):
    """
    This function extracts GOES images for a given date_time and parallel_index. (whole GOES slice, used for inference which differs from images used in training that have a matched orbit with scatterometers.)
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create
    images of the specified size (128x128).

    Args:
        date_time (numpy.datetime64): The time of the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        channels (str or list): The channel(s) of interest. Default is "C01".
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Default is 'noaa-goes16/ABI-L2-CMIPF'.

    Returns:
        images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).
    """

    # ignore divide by zero errors which occur when the GOES data can't form a 128x128 image
    np.seterr(invalid='ignore', divide='ignore')

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**1024)
    urls = get_goes_url(date_time, goes_aws_url_folder=goes_aws_url_folder, goes_channel=channels)
    with fs.open(urls[0], mode="rb") as f:
        print("INFO : Reading file:", urls[0])
        goes_image = xr.open_dataset(f)
        goes_image = goes_image.rename({"x": "x_index", "y": "y_index"})

        # Assign the index coordinates (if not already done)
        goes_image = goes_image.assign_coords(
            x_index=np.arange(goes_image.sizes["x_index"]),
            y_index=np.arange(goes_image.sizes["y_index"]),
        )

        images = []
        goes_image.load()
        print("INFO : Extracting images")
        for i in range(parallel_index.shape[0]):
            for j in range(parallel_index.shape[1]):
                try:
                    x_mean = parallel_index[i][j][1].mean().astype(int)
                    x_min = x_mean - 63
                    x_max = x_mean + 63
                    y_mean = parallel_index[i][j][0].mean().astype(int)
                    y_min = y_mean - 63
                    y_max = y_mean + 63
                    image = goes_image.CMI.sel(
                        x_index=slice(x_min, x_max), y_index=slice(y_min, y_max)
                    )

                    target_size = (128, 128)

                    padded_image = np.pad(
                        image,
                        (
                            (
                                (target_size[0] - image.shape[0]) // 2,
                                (target_size[0] - image.shape[0] + 1) // 2,
                            ),
                            (
                                (target_size[1] - image.shape[1]) // 2,
                                (target_size[1] - image.shape[1] + 1) // 2,
                            ),
                        ),
                        constant_values=0,
                    )

                except:
                    images.append(np.zeros((128, 128)))

                    continue
                images.append(padded_image)

        return images

extract_goes_production(time_choice, polar_data, parallel_index, channels, goes_aws_url_folder)

Extracts GOES data for a specific time from the polar data and returns the images along with valid latitudes, longitudes, and times.

Parameters:

    time_choice (str): The time for which to extract GOES data, in 'YYYY-MM-DD HH:MM:SS' format. Required.
    polar_data (Dataset): Polar data containing latitude and longitude information. Used to create a grid of valid latitudes and longitudes. Required.
    parallel_index (int): Index for parallel processing, used to identify the specific GOES data to extract, generated by the index_parallel function. Required.
    channels (list): List of GOES channels to extract. Required.
    goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF". Required.

Returns:

    images (ndarray): Array of extracted GOES images for the specified time.
    valid_lats (ndarray): Array of valid latitudes corresponding to the GOES images.
    valid_lons (ndarray): Array of valid longitudes corresponding to the GOES images.
    valid_times (ndarray): Array of valid times corresponding to the GOES images.

Source code in windscangeo\func_inference.py
def extract_goes_production(time_choice, polar_data, parallel_index,channels,goes_aws_url_folder):

    """ 
    Extracts GOES data for a specific time from the polar data and returns the images along with valid latitudes, longitudes, and times.

    Args:
        time_choice (str): The time for which to extract GOES data, in 'YYYY-MM-DD HH:MM:SS' format.
        polar_data (xarray.Dataset): Polar data containing latitude and longitude information. Used to create a grid of valid latitudes and longitudes.
        parallel_index (int): Index for parallel processing, used to identify the specific GOES data to extract, generated by the `index_parallel` function.
        channels (list): List of GOES channels to extract.
        goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF".

    Returns:
        images (np.ndarray): Array of extracted GOES images for the specified time.
        valid_lats (np.ndarray): Array of valid latitudes corresponding to the GOES images
        valid_lons (np.ndarray): Array of valid longitudes corresponding to the GOES images
        valid_times (np.ndarray): Array of valid times corresponding to the GOES images.

    """
    time_formated = (
        np.datetime64(time_choice).astype("datetime64[ns]").astype("float64")
    )

    longrid, latgrid = np.meshgrid(polar_data["longitude"], polar_data["latitude"])
    lon_array = longrid.flatten()
    lat_array = latgrid.flatten()
    time_array = np.full_like(lon_array, time_formated)

    valid_lons = lon_array
    valid_lats = lat_array
    valid_times = time_array

    print('INFO : Extracting GOES data')
    images = extract_goes_inference(np.datetime64(time_choice), parallel_index,channels,goes_aws_url_folder)


    images = np.expand_dims(images, axis=1)


    return images, valid_lats, valid_lons, valid_times

extract_matching_orbits(scatterometer_data_path, date, lat_range=[-90, 90], lon_range=[-180, 180], goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel='C01', goes_image_size=128, verbose=True, save=True)

Extracts matching orbits from scatterometer data (pre-downloaded) and GOES images (automatically extracted) for a given date. The function filters the data to only include daylight observations and returns the images and numerical data in a dictionary format.

Parameters:

    scatterometer_data_path (str): Path to the pre-downloaded scatterometer data. Data must be in netCDF format. See the tutorial for downloading data. Required.
    date (str): Date for which to extract the data, in 'YYYY-MM-DD' format. Required.
    lat_range (tuple): Latitude range to filter the data. Default: [-90, 90] (all latitudes).
    lon_range (tuple): Longitude range to filter the data. Default: [-180, 180] (all longitudes).
    goes_aws_url_folder (str): AWS URL folder for the GOES data. Default: 'noaa-goes16/ABI-L2-CMIPF'.
    goes_channel (str): GOES channel to extract. Default: 'C01' (visible).
    goes_image_size (int): Size of the GOES images to extract. Default: 128.
    verbose (bool): If True, prints additional information during the extraction process. Default: True.
    save (bool): If True, saves the preloaded data to a compressed .npz file. Default: True.

Returns:

    images (list): List of filtered GOES images.
    numerical_data (dict): Dictionary containing filtered numerical data (latitudes, longitudes, times, wind speeds).
    saved_file_path (str): Path to the saved .npz file if save is True, otherwise None.

Source code in windscangeo\main_func.py
def extract_matching_orbits(scatterometer_data_path: str,
                            date : str ,
                            lat_range : tuple = [-90, 90],
                            lon_range : tuple = [-180, 180],
                            goes_aws_url_folder : str = "noaa-goes16/ABI-L2-CMIPF",
                            goes_channel : str = "C01",
                            goes_image_size : int = 128,
                            verbose : bool = True,
                            save : bool =True
                            ):

    """    Extracts matching orbits from scatterometer data (pre-downloaded) and GOES images (automatically extracted) for a given date. 
    The function filters the data to only include daylight observations and returns the images and numerical data in a dictionary format.

    Args:
        scatterometer_data_path (str): Path to the pre-downloaded scatterometer data. Data must be in .netCDF format. See tutorial for downloading data.
        date (str): Date for which to extract the data in 'YYYY-MM-DD' format.
        lat_range (tuple): Latitude range to filter the data, default is [-90,180] (all latitudes).
        lon_range (tuple): Longitude range to filter the data, default is [-180, 180] (all longitudes).
        goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF".
        goes_channel (str): GOES channel to extract, default is "C01" (visible).
        goes_image_size (int): Size of the GOES images to extract, default is 128.
        verbose (bool): If True, prints additional information during the extraction process, default is True.
        save (bool): If True, saves the preloaded data to a compressed .npz file, default is True.

    Returns:
        images (list): List of filtered GOES images.
        numerical_data (dict): Dictionary containing filtered numerical data (latitudes, longitudes, times, wind speeds).
        saved_file_path (str): Path to the saved .npz file if save is True, otherwise None.

    """
    if verbose:
        print("START : Extracting matching orbits for date:", date) 
        print("INFO : lat_range:", lat_range,) 
        print("INFO : lon_range:", lon_range,) 
        print("INFO : goes_channel:", goes_channel, 'at url:', goes_aws_url_folder)

    # Extract scatterometer data in correct format

    (
        observation_times_local,
        observation_lats_local,
        observation_lons_local,
        observation_wind_speeds_local,
    ) = extract_scatter_multisat(
        scatterometer_data_path, date, lat_range, lon_range,verbose=verbose)


    # filter nighttime data
    valid_times, valid_lats, valid_lons, valid_wind_speeds = filter_nighttime(
        observation_times_local,
        observation_lats_local,
        observation_lons_local,
        observation_wind_speeds_local,
        verbose=verbose
        )

    # Extract GOES data from matching orbits 
    # If this is first run with dataset, it will create a folder with indices which takes time.

    images = extract_goes(  
        observation_times=valid_times[0:100], #TODO: REMOVE THIS LIMITATION, only for debugging
        observation_lats=valid_lats[0:100],
        observation_lons=valid_lons[0:100],
        scatterometer_data_path=scatterometer_data_path,
        goes_aws_url_folder=goes_aws_url_folder,
        goes_channel=goes_channel,
        goes_image_size=goes_image_size,
        verbose=verbose
        )

    # Package the data into a dictionary for easy access and logging. The observation_wind_speeds is the target variable.
    numerical_data = {
        "observation_lats": np.array(valid_lats),
        "observation_lons": np.array(valid_lons),
        "observation_times": np.array(valid_times),
        "observation_wind_speeds": np.array(valid_wind_speeds),
    }

    # Filter the images and numerical data to remove nans, invalid images

    images_filtered, numerical_data_filtered = package_data(
        images, numerical_data, solar_conversion=False, verbose=verbose
    )

    if verbose:
        print("END : Extracted :", len(images_filtered)," training pairs")

    # Save the preloaded data to a compressed .npz file for later use
    if save:

        # check if folder exists
        if not os.path.exists("./saved_files/"):
            os.makedirs("./saved_files/")

        channels_safe_name = goes_channel
        saved_file_path = f"./saved_files/file_preloaded_{date}_{channels_safe_name}.npz"
        print("INFO : Saving preloaded data to file : ", saved_file_path)

        np.savez_compressed(
            saved_file_path,
            images=images_filtered,
            numerical_data=numerical_data_filtered,
        )

        return images_filtered, numerical_data_filtered, saved_file_path


    return images_filtered, numerical_data_filtered
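
A usage sketch (the data path and date are placeholders; the scatterometer folder must contain pre-downloaded .nc files, and fetching the GOES imagery requires access to the public noaa-goes16 S3 bucket):

from windscangeo.main_func import extract_matching_orbits

images, numerical_data, saved_file_path = extract_matching_orbits(
    scatterometer_data_path="./scatterometer_data/",   # note the trailing slash
    date="2022-07-01",
    lat_range=[10, 40],
    lon_range=[-80, -40],
    goes_channel="C01",
    goes_image_size=128,
)

print(len(images), "training pairs saved to", saved_file_path)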

extract_scatter(polar_data, date, lat_range, lon_range, verbose=True, main_variable='wind_speed')

This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range, and longitude range, and returns four arrays: time of observation, latitude, longitude, and the main variable.

Parameters:

    polar_data (Dataset): The scatterometer dataset (ASCAT, HYSCAT, etc.). Required.
    date (datetime64): The time of the scatterometer data. Required.
    lat_range (tuple): The latitude range of the scatterometer data. Required.
    lon_range (tuple): The longitude range of the scatterometer data. Required.
    verbose (bool): If True, the function will print the progress of the extraction. Default: True.
    main_variable (str): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification, etc. Default: 'wind_speed'.

Returns:

    observation_times (ndarray): The times of observation of the scatterometer data.
    observation_lats (ndarray): The latitudes of the scatterometer data.
    observation_lons (ndarray): The longitudes of the scatterometer data.
    observation_main_parameter (ndarray): The main parameter extracted (wind_speed).

Source code in windscangeo\func.py
def extract_scatter(
    polar_data,
    date,
    lat_range,
    lon_range,
    verbose=True,
    main_variable="wind_speed",
):
    """
    This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range and longitude range.
    The function then saves the data into 4 numpy files : time of observation, latitude, longitude and main variable.

    Args:
        polar_data (xarray.Dataset): The scatterometer dataset (ASCAT, HYSCAT etc).
        date (numpy.datetime64): The time of the scatterometer data.
        lat_range (tuple): The latitude range of the scatterometer data.
        lon_range (tuple): The longitude range of the scatterometer data.
        verbose (bool): If True, the function will print the progress of the extraction.
        main_variable (str): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification etc.

    Returns:
        observation_times (numpy.ndarray): The time of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitude of the scatterometer data.
        observation_lons (numpy.ndarray): The longitude of the scatterometer data.
        observation_main_parameter (numpy.ndarray): main parameter extracted (wind_speed).

    """

    polar = polar_data.sel(
        time=slice(date, date),
        latitude=slice(lat_range[0], lat_range[1]),
        longitude=slice(lon_range[0], lon_range[1]),
    )

    seperated_scatter = savedataseperated(polar, polar[main_variable],verbose=verbose)

    observation_times = seperated_scatter[2]
    observation_lats = seperated_scatter[0]
    observation_lons = seperated_scatter[1]
    observation_wind_speeds = seperated_scatter[3]



    return (
        observation_times,
        observation_lats,
        observation_lons,
        observation_wind_speeds,
    )

extract_scatter_multisat(scatterometer_data_path, date, lat_range, lon_range, verbose=True)

Extracts scatterometer data from multiple files (.nc) in a specified directory.

Parameters:

    scatterometer_data_path (str): Path to the directory containing scatterometer data files. Required.
    date (datetime): Date for which to extract data. Required.
    lat_range (tuple): Latitude range (min, max) for filtering data. Required.
    lon_range (tuple): Longitude range (min, max) for filtering data. Required.
    verbose (bool): If True, prints progress information. Default: True.

Returns:

    tuple: A tuple containing:
        - list of datetime: observation times
        - list of float: latitudes
        - list of float: longitudes
        - list of float: wind speeds

Source code in windscangeo\func_inference.py
def extract_scatter_multisat(
    scatterometer_data_path, date, lat_range, lon_range,verbose=True
):

    """
    Extracts scatterometer data from multiple files (`.nc`) in a specified directory.

    Args:
        scatterometer_data_path (str): Path to the directory containing scatterometer data files.
        date (datetime): Date for which to extract data.
        lat_range (tuple): Latitude range (min, max) for filtering data.
        lon_range (tuple): Longitude range (min, max) for filtering data.
        verbose (bool): If True, prints progress information.

    Returns:
        tuple: A tuple containing:
            - list of datetime: observation times
            - list of float: latitudes
            - list of float: longitudes
            - list of float: wind speeds
    """

    observation_times = []
    observation_lats = []
    observation_lons = []
    observation_wind_speeds = []

    if verbose:
        print("INFO : Extracting scatterometer data from folder : ", scatterometer_data_path)
        print("___")

    for file in os.listdir(scatterometer_data_path):
        if ".nc" in file:
            # Open the file
            file_path = scatterometer_data_path + file
            polar_data = xr.open_dataset(file_path)
            (
                observation_times_local,
                observation_lats_local,
                observation_lons_local,
                observation_wind_speeds_local,
            ) = extract_scatter(
                polar_data, date, lat_range, lon_range, verbose=verbose
            )
            observation_times.extend(observation_times_local)
            observation_lats.extend(observation_lats_local)
            observation_lons.extend(observation_lons_local)
            observation_wind_speeds.extend(observation_wind_speeds_local)

    if verbose : 
        print("___")
        print(f"INFO : Total number of scatterometer data points: {len(observation_times)}")
    return (
        observation_times,
        observation_lats,
        observation_lons,
        observation_wind_speeds,
    )
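
Because file paths are built by plain string concatenation (scatterometer_data_path + file), the directory path should end with a separator. A short sketch (directory name and date are placeholders):

import numpy as np

times, lats, lons, wind_speeds = extract_scatter_multisat(
    "./scatterometer_data/",             # note the trailing slash
    date=np.datetime64("2023-04-01"),
    lat_range=(10.0, 30.0),
    lon_range=(-60.0, -20.0),
    verbose=True,
)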

fill_nans(images)

This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

Parameters:

    images (ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images. Required.

Returns:

    images (ndarray): A 4D numpy array with NaN values replaced by zeros.

Source code in windscangeo\func.py, lines 808-820
def fill_nans(images):
    """
    This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.

    Returns:
        images (numpy.ndarray): A 4D numpy array with NaN values replaced by zeros.
    """
    images = np.nan_to_num(images, nan=0.0)
    print("INFO : Filled nans")
    return images

filter_invalid(images, numerical_data, min_nonzero_pixels=50)

This function filters out invalid images and corresponding numerical data based on two criteria: 1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty). 2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

Parameters:

    images (ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images. Required.
    numerical_data (dict): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images. Required.
    min_nonzero_pixels (int): The minimum number of non-zero pixels required for an image to be considered valid. Default: 50.

Returns:

    filtered_images (ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
    filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

Source code in windscangeo\func.py, lines 755-805
def filter_invalid(
    images,
    numerical_data,
    min_nonzero_pixels=50,
):

    """
    This function filters out invalid images and corresponding numerical data based on two criteria:
    1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty).
    2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.
        numerical_data (dict): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images.
        min_nonzero_pixels (int): The minimum number of non-zero pixels required for an image to be considered valid. Default is 50.

    Returns:
        filtered_images (numpy.ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
        filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

    """
    # Sums of pixel values in each image
    sums_images = [np.nansum(x) for x in images]

    # Counts of non-zero pixels in each image
    nonzero_counts = [np.count_nonzero(x) for x in images]

    # Build a "mask_invalid" array of indices that fail any criterion:
    # 1) sum == 0 (completely empty)
    # 2) nonzero pixel count < min_nonzero_pixels (not enough data)

    mask_valid = np.where(
        (np.array(sums_images) != 0) & (np.array(nonzero_counts) >= min_nonzero_pixels)
    )[0]

    # Delete the invalid entries from each array
    filtered_numerical_data = {
        key: value[mask_valid] for key, value in numerical_data.items()
    }
    filtered_images = images[mask_valid]
    n_removed_images = len(images) - len(filtered_images)

    print(
        "INFO : Filtered invalid images. Removed {} entries.".format(
            n_removed_images
        )
    )
    return (
        filtered_images,
        filtered_numerical_data,
    )
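
A self-contained sketch of the filtering behaviour on synthetic data (array sizes and the threshold are arbitrary):

import numpy as np

images = np.zeros((3, 1, 8, 8))
images[0, 0, :4, :4] = 1.0   # 16 non-zero pixels -> dropped when min_nonzero_pixels=20
images[1, 0] = 1.0           # 64 non-zero pixels -> kept
                             # images[2] stays all zeros -> dropped (sum == 0)
numerical_data = {"wind_speeds": np.array([5.0, 7.2, 3.1])}

kept_images, kept_data = filter_invalid(images, numerical_data, min_nonzero_pixels=20)
print(kept_images.shape)          # (1, 1, 8, 8)
print(kept_data["wind_speeds"])   # [7.2]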

filter_nighttime(observation_times, observation_lats, observation_lons, observation_wind_speeds, min_hour=10, max_hour=19, verbose=True)

This function filters the scatterometer data to only include observations that were made during daylight hours. The function checks the hour of each observation time and only keeps those that fall within the specified range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

Parameters:

    observation_times (ndarray): The times of observation of the scatterometer data. Required.
    observation_lats (ndarray): The latitudes of the scatterometer data. Required.
    observation_lons (ndarray): The longitudes of the scatterometer data. Required.
    observation_wind_speeds (ndarray): The wind speeds of the scatterometer data. Required.
    min_hour (int): The minimum hour of the day to include. Default: 10.
    max_hour (int): The maximum hour of the day to include. Default: 19.
    verbose (bool): If True, prints the number of valid scatterometer data points at daylight. Default: True.

Returns:

    valid_times (list): A list of valid observation times that fall within the specified hour range.
    valid_lats (list): A list of valid latitudes corresponding to the valid observation times.
    valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
    valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

Source code in windscangeo\func.py, lines 823-871
def filter_nighttime(
    observation_times,
    observation_lats,
    observation_lons,
    observation_wind_speeds,
    min_hour=10,
    max_hour=19,
    verbose=True,
):
    """
    This function filters the scatterometer data to only include observations that were made during daylight hours.
    The function checks the hour of each observation time and only keeps those that fall within the specified
    range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        observation_wind_speeds (numpy.ndarray): The wind speeds of the scatterometer data.
        min_hour (int): The minimum hour of the day to include (default is 10).
        max_hour (int): The maximum hour of the day to include (default is 19).
        verbose (bool): If True, prints the number of valid scatterometer data points at daylight.

    Returns:
        valid_times (list): A list of valid observation times that fall within the specified hour range.
        valid_lats (list): A list of valid latitudes corresponding to the valid observation times
        valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
        valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

    """

    valid_times = []
    valid_lats = []
    valid_lons = []
    valid_wind_speeds = []

    for idx in range(len(observation_times)):
        only_hour = int(
            observation_times[idx].astype("datetime64[ns]").astype("str")[11:13]
        )
        if min_hour <= only_hour <= max_hour:
            valid_times.append(observation_times[idx])
            valid_lats.append(observation_lats[idx])
            valid_lons.append(observation_lons[idx])
            valid_wind_speeds.append(observation_wind_speeds[idx])

    if verbose:
        print(f"INFO : Total number of scatterometer data points at daylight : {len(valid_times)}")
    return valid_times, valid_lats, valid_lons, valid_wind_speeds
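
A minimal sketch with two synthetic observations, one inside and one outside the daylight window:

import numpy as np

times = np.array(["2023-04-01T14:30:00", "2023-04-01T03:10:00"], dtype="datetime64[ns]")
lats = np.array([15.0, 16.0])
lons = np.array([-40.0, -41.0])
speeds = np.array([8.5, 6.0])

valid_times, valid_lats, valid_lons, valid_speeds = filter_nighttime(
    times, lats, lons, speeds, min_hour=10, max_hour=19
)
# Only the 14:30 UTC observation is kept; the 03:10 UTC one is discarded.
print(valid_speeds)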

form_arrays_buoy(buoy, date_choice)

Form arrays from buoy data for a specific date.

Parameters:

    buoy (Dataset): Buoy data containing time, latitude, longitude, and wind speed. Required.
    date_choice (str): Date for which to extract buoy data, in 'YYYY-MM. Required.

Returns:

    lat (ndarray): Array of buoy latitudes.
    lon (ndarray): Array of buoy longitudes.
    time (ndarray): Array of buoy observation times in nanoseconds since epoch.
    wind_speed (ndarray): Array of buoy wind speeds.
    buoy_name (ndarray): Array of buoy names.

Source code in windscangeo\func_inference.py, lines 163-201
def form_arrays_buoy(buoy, date_choice):

    """
    Form arrays from buoy data for a specific date.

    Args:
        buoy (xarray.Dataset): Buoy data containing time, latitude, longitude, and wind speed.
        date_choice (str): Date for which to extract buoy data, in 'YYYY-MM

    Returns:
        lat (np.ndarray): Array of buoy latitudes.
        lon (np.ndarray): Array of buoy longitudes.
        time (np.ndarray): Array of buoy observation times in nanoseconds since epoch.
        wind_speed (np.ndarray): Array of buoy wind speeds.
        buoy_name (np.ndarray): Array of buoy names.

    """
    try:
        start = np.datetime64(date_choice) - np.timedelta64(5, "m")
        end = np.datetime64(date_choice) + np.timedelta64(5, "m")

        wind_speed = buoy.sel(time=slice(start, end)).WS_401.values.flatten()
        time = (
            buoy.sel(time=slice(start, end))
            .time.values.astype("datetime64[ns]")
            .astype("int64")
        )
        lat = np.full(len(wind_speed), buoy.lat.values)
        lon = np.full(len(wind_speed), buoy.lon.values)
        buoy_name = np.full(len(wind_speed), buoy.platform_code, dtype=object)
    except:
        wind_speed = np.array([])
        time = np.array([])
        lat = np.array([])
        lon = np.array([])
        buoy_name = np.array([])

        print("date selection unavailable")
    return lat, lon, time, wind_speed, buoy_name
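
A usage sketch (the buoy file name is hypothetical; as in the code above, the buoy dataset is assumed to expose time, lat, lon, platform_code and a WS_401 wind-speed variable):

import xarray as xr

buoy = xr.open_dataset("./buoy_data/41047.nc")   # hypothetical buoy file
lat, lon, time, wind_speed, name = form_arrays_buoy(buoy, "2022-07-15T17:30:00")
print(wind_speed.size, "buoy samples within +/- 5 minutes of the chosen time")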

get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel='C01')

This function gets the nearest GOES-16 files from the time given. The function returns a list of urls to the files. The function uses the s3fs library to access the AWS GOES-16 data.

Parameters:

    time (datetime64[ns]): The time of the scatterometer data. Required.
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES-16 data is stored. Default: 'noaa-goes16/ABI-L2-CMIPF'.
    goes_channel (str): The channel of interest. Default: 'C01'.

Returns:

    urls (list): A list of urls to the GOES-16 files.

Source code in windscangeo\func.py, lines 431-472
def get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel="C01"):
    """
    This function gets the nearest GOES-16 files from the time given.
    The function returns a list of urls to the files.
    The function uses the s3fs library to access the AWS GOES-16 data.

    Args:
        time (numpy.datetime[ns]): The time of the scatterometer data.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES-16 data is stored.
        goes_channel (list): The channel of interest.

    Returns:
        urls (list): A list of urls to the GOES-16 files.


    """
    date_c = time.astype("datetime64[ns]")
    date = pd.to_datetime(date_c)
    date_str = date.strftime("%Y/%j/%H")
    min = int(date.strftime("%M"))
    min_range = [(min + i) % 60 for i in range(-6, 7)]
    min_range_str = [f"{x:02d}" for x in min_range]
    fs = s3fs.S3FileSystem(anon=True)
    # get the nearest goes file from time

    urls = []
    channel = goes_channel
    path = f"{goes_aws_url_folder}/{date_str}"
    files = fs.ls(path)
    filter_channel = [x for x in files if channel in x]
    if len(filter_channel) == 0:
        print(f"INFO :No file found for {channel} on day {date_str}, skipping file")
        return
    file = [x for x in filter_channel if x[73:75] in min_range_str]
    if len(file) == 0:
        print(
            f"INFO :No file found for {channel} on day {date_str} for minute {min}, skipping file"
        )
        return np.zeros(len(goes_channel))
    urls.append("s3://" + file[0])

    return urls
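
A usage sketch (requires anonymous network access to the public NOAA bucket; the timestamp is arbitrary):

import numpy as np

urls = get_goes_url(
    np.datetime64("2022-07-15T17:32:00"),
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
    goes_channel="C01",
)
print(urls)   # e.g. ['s3://noaa-goes16/ABI-L2-CMIPF/2022/196/17/...'] or None if nothing matched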

get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search, goes_image_size=128)

This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 .nc file.

Parameters:

    ds (Dataset): The xarray dataset containing the GOES data. Required.
    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    lat_grd (ndarray): The latitude grid of the scatterometer data. Required.
    lon_grd (ndarray): The longitude grid of the scatterometer data. Required.
    lat_search (float): The latitude to search for in the GOES data. Required.
    lon_search (float): The longitude to search for in the GOES data. Required.
    goes_image_size (int): The size of the output image. Default: 128.

Returns:

    padded_image (DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

Source code in windscangeo\func.py, lines 526-590
def get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search,goes_image_size=128):

    """
    This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 `.nc` file.

    Args:
        ds (xarray.Dataset): The xarray dataset containing the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.
        goes_image_size (int): The size of the output image. Default is 128.

    Returns:
        padded_image (xarray.DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

    """
    index_row = np.where(
        lat_grd == lat_search,
    )
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    pixels_from_center = (goes_image_size-1) // 2
    mean_row = rows_goes.mean().astype(int)
    min_row = mean_row - pixels_from_center
    max_row = mean_row + pixels_from_center

    mean_col = columns_goes.mean().astype(int)
    min_col = mean_col - pixels_from_center
    max_col = mean_col + pixels_from_center

    if "CMI" in ds: # If using GOES-16 L2 processed data
        image = ds.CMI[min_row:max_row, min_col:max_col].values

    elif "Rad" in ds: #If using GOES-16 L1b data
        image = ds.Rad[min_row:max_row, min_col:max_col].values

    # debug
    # print(min_row,'= min_row', max_row,'= max_row', min_col, '= min_col', max_col, '= max_col')
    target_size = (goes_image_size, goes_image_size)

    padded_image = np.pad(
        image,
        (
            (
                (target_size[0] - image.shape[0]) // 2,
                (target_size[0] - image.shape[0] + 1) // 2,
            ),
            (
                (target_size[1] - image.shape[1]) // 2,
                (target_size[1] - image.shape[1] + 1) // 2,
            ),
        ),
        constant_values=0,
    )

    padded_image = xr.DataArray(padded_image, dims=("x", "y"))
    return padded_image

get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125)

Finds the corresponding GOES row and column indices for each scatterometer point using a BallTree for efficiency, and then filtering points to form a square bounding box.

Parameters:

    lat_grid (ndarray): 2D array of latitudes from the scatterometer data. Required.
    lon_grid (ndarray): 2D array of longitudes from the scatterometer data. Required.
    Goeslat (ndarray): 2D array of latitudes from the GOES data. Required.
    Goeslon (ndarray): 2D array of longitudes from the GOES data. Required.
    radius (float): Radius in degrees to define the bounding box around each scatterometer point. Default: 0.125.

Returns:

    indices_array (ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.

Source code in windscangeo\func.py, lines 108-184
def get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125):
    """
    Finds the corresponding GOES row and column indices for each scatterometer point
    using a BallTree for efficiency, and then filtering points to form a square bounding box.

    Args:
        lat_grid (numpy.ndarray): 2D array of latitudes from the scatterometer data.
        lon_grid (numpy.ndarray): 2D array of longitudes from the scatterometer data.
        Goeslat (numpy.ndarray): 2D array of latitudes from the GOES data.
        Goeslon (numpy.ndarray): 2D array of longitudes from the GOES data.
        radius (float): Radius in degrees to define the bounding box around each scatterometer point.
    Returns:
        indices_array (numpy.ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.


    """

    print("INFO : Calculating indices")
    # Flatten GOES data
    Goeslat_flat = Goeslat.flatten()
    Goeslon_flat = Goeslon.flatten()
    goes_points = np.column_stack((Goeslat_flat, Goeslon_flat))

    # Build BallTree with haversine distance
    goes_points_rad = np.radians(goes_points)
    goes_tree = BallTree(goes_points_rad, metric="haversine")

    # Flatten scatter grids
    lat_flat = lat_grid.flatten()
    lon_flat = lon_grid.flatten()
    scatter_points = np.column_stack((lat_flat, lon_flat))
    scatter_points_rad = np.radians(scatter_points)

    # Radius for broad-phase query: diagonal of the bounding box
    # Square box ±radius: diagonal = radius * sqrt(2)
    diag_radius = radius * np.sqrt(2)
    diag_radius_rad = np.radians(diag_radius)

    indices_array = np.empty(lat_flat.shape, dtype=object)
    goes_shape = Goeslat.shape

    for i, (lat_val, lon_val) in enumerate(zip(lat_flat, lon_flat)):
        # Broad-phase: query all points within diagonal distance
        candidate_indices = goes_tree.query_radius(
            np.array([scatter_points_rad[i]]), r=diag_radius_rad
        )[0]

        if candidate_indices.size == 0:
            # No points found, store empty
            indices_array[i] = (np.array([], dtype=int), np.array([], dtype=int))
            continue

        # Post-filter candidates to keep only those in the bounding box
        lat_min = lat_val - radius
        lat_max = lat_val + radius
        lon_min = lon_val - radius
        lon_max = lon_val + radius

        cand_lats = Goeslat_flat[candidate_indices]
        cand_lons = Goeslon_flat[candidate_indices]

        mask = (
            (cand_lats >= lat_min)
            & (cand_lats <= lat_max)
            & (cand_lons >= lon_min)
            & (cand_lons <= lon_max)
        )

        final_indices = candidate_indices[mask]

        # Convert these flat indices back to row,col
        rows, cols = np.unravel_index(final_indices, goes_shape)
        indices_array[i] = (rows, cols)

    # Reshape indices_array to the original shape
    indices_array = indices_array.reshape(lat_grid.shape)
    return indices_array
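
A toy example on small synthetic grids (sizes are chosen only to keep the example fast; real GOES grids are far larger):

import numpy as np

scatter_lats = np.linspace(14.0, 15.0, 5)
scatter_lons = np.linspace(-40.0, -39.0, 5)
lon_grid, lat_grid = np.meshgrid(scatter_lons, scatter_lats)

# Synthetic "GOES" grid at finer resolution over the same area
goes_lats = np.linspace(13.5, 15.5, 50)
goes_lons = np.linspace(-40.5, -38.5, 50)
Goeslon, Goeslat = np.meshgrid(goes_lons, goes_lats)

indices = get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125)
rows, cols = indices[0, 0]   # GOES pixels inside the +/-0.125 deg box around the first scatter point
print(rows.size, "matching GOES pixels")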

get_mlp(in_features, hidden_units, out_features)

Returns a MLP head

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py, lines 173-185
def get_mlp(in_features, hidden_units, out_features):
    """
    Returns a MLP head

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    dims = [in_features] + hidden_units + [out_features]
    layers = []
    for dim1, dim2 in zip(dims[:-2], dims[1:-1]):
        layers.append(nn.Linear(dim1, dim2))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*layers)
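
For example, get_mlp(1024, [512, 256], 1) builds Linear(1024, 512) -> ReLU -> Linear(512, 256) -> ReLU -> Linear(256, 1):

import torch

head = get_mlp(1024, [512, 256], 1)
x = torch.randn(4, 1024)      # batch of 4 feature vectors
print(head(x).shape)          # torch.Size([4, 1])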

goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search)

This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

Parameters:

    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    lat_grd (ndarray): The latitude grid of the scatterometer data. Required.
    lon_grd (ndarray): The longitude grid of the scatterometer data. Required.
    lat_search (float): The latitude to search for in the GOES data. Required.
    lon_search (float): The longitude to search for in the GOES data. Required.

Returns:

    min_row (int): The minimum row index of the GOES image.
    max_row (int): The maximum row index of the GOES image.
    min_col (int): The minimum column index of the GOES image.
    max_col (int): The maximum column index of the GOES image.

Source code in windscangeo\func.py, lines 719-752
def goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search):
    """
    This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

    Args:
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.

    Returns:
        min_row (int): The minimum row index of the GOES image.
        max_row (int): The maximum row index of the GOES image.
        min_col (int): The minimum column index of the GOES image.
        max_col (int): The maximum column index of the GOES image.
    """

    index_row = np.where(lat_grd == lat_search)
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    min_row = rows_goes.min()
    max_row = rows_goes.max()

    min_col = columns_goes.min()
    max_col = columns_goes.max()

    return min_row, max_row, min_col, max_col

index_parallel(ds, ScatterDataset)

Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

Parameters:

    ds (Dataset): xarray Dataset containing a GOES image, used for its spatial_resolution attribute and to derive the GOES latitude/longitude grid. Required.
    ScatterDataset (Dataset): xarray Dataset containing scatterometer data. Required.

Returns:

    parallel_indice_values (ndarray): 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.

Source code in windscangeo\func.py, lines 248-302
def index_parallel(ds, ScatterDataset):
    """
    Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

    Args:
        ScatterDataset: xarray Dataset containing scatterometer data.
        scatter_name: Name for the output file.
        output_path: Path to save the output file.

    Returns:
        parallel_indice_values: 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.
    """

    create_folder("satellite_indices")
    ds_spatial_resolution = ds.spatial_resolution
    ds_spatial_resolution.replace(" ", "_")

    name_str = f"lat_{ScatterDataset.latitude.min().values}_{ScatterDataset.latitude.max().values}_lon_{ScatterDataset.longitude.min().values}_{ScatterDataset.longitude.max().values}_res_{ds_spatial_resolution}"
    name_str = name_str.replace(".", "_")
    if os.path.exists(
        f"./satellite_indices/{ds_spatial_resolution}_index.npy"
    ):
        parallel_index = np.load(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            allow_pickle=True,
        )

        return parallel_index

    else:
        print(
            "INFO : Satellite index file not found, creating new index file. This might take a while."
        )

        # Extract scatterometer latitudes and longitudes
        Latitudes_Scatter = ScatterDataset["latitude"].values
        Longitudes_Scatter = ScatterDataset["longitude"].values

        # Create a meshgrid of scatterometer coordinates
        lon_grid, lat_grid = np.meshgrid(Longitudes_Scatter, Latitudes_Scatter)

        # Extract GOES latitudes and longitudes
        Goeslat, Goeslon = calculate_degrees(ds)
        Goeslat[np.isnan(Goeslat)] = 999
        Goeslon[np.isnan(Goeslon)] = 999
        # Use the optimized get_indices function
        parallel_indice_values = get_indices(lat_grid, lon_grid, Goeslat, Goeslon)

        # Save the indices array
        np.save(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            parallel_indice_values,
        )

        return parallel_indice_values

inference_full_goes_image(datetime, scatterometer_data_path, result_path_folder, model_parameters, buoy_path, normalization_factors, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel='C01')

Runs the inference on a full GOES image for a given datetime using the pre-trained model from the function train_test_model.

Parameters:

    datetime (str): Date and time in 'YYYY-MM-DD HH:MM:SS'. Required.
    scatterometer_data_path (str): Path to the pre-downloaded scatterometer data. Must be the same as used for extracting the matching orbits. Required.
    result_path_folder (str): Path to the folder where the results of the model training are stored. Required.
    model_parameters (dict): Dictionary containing model parameters such as batch size, image size, learning rate, etc. Must be identical to those used in train_test_model. Required.
    buoy_path (str): Path to the buoy data folder. Must be a folder containing the buoy data in .nc format. More details about how to download the buoy data can be found in the tutorial. Required.
    normalization_factors (dict): Dictionary containing normalization factors (mean and std) for the images dataset. Required.
    goes_aws_url_folder (str): AWS URL folder for the GOES data. Must be the same as used in the matching orbits extraction. Default: 'noaa-goes16/ABI-L2-CMIPF'.
    goes_channel (str): GOES channel to extract. Default: 'C01' (visible).

Returns:

    None, but saves the inference results in a folder named inference_<datetime> in the result_path_folder.

Source code in windscangeo\main_func.py, lines 357-460
def inference_full_goes_image(
    datetime: str,
    scatterometer_data_path : str,
    result_path_folder : str,
    model_parameters : dict,
    buoy_path : str,
    normalization_factors : dict,
    goes_aws_url_folder : str = "noaa-goes16/ABI-L2-CMIPF",
    goes_channel : str = 'C01',
):

    """ Runs the inference on a full GOES image for a given datetime using the pre-trained model from the function `train_test_model`.

    Args:
            datetime (str): Date and time in 'YYYY-MM-DD HH:MM:SS'
            scatterometer_data_path (str): Path to the pre-downloaded scatterometer data. Must be the same used in for extracting the matching orbits.
            result_path_folder (str): Path to the folder where the results of the model training are
            model_parameters (dict): Dictionary containing model parameters such as batch size, image size, learning rate, etc. Must be identical as those used in `train_test_model`.
            buoy_path (str): Path to the buoy data folder. Must be a folder containing the buoy data in `.nc` format. More details about how to download the buoy data can be found in the tutorial.
            normalization_factors (dict): Dictionary containing normalization factors (mean and std) for the images dataset.
            goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF". Must be the same as used in the matching orbits extraction.
            goes_channel (str): GOES channel to extract, default is "C01" (visible).

    Returns:
            None, but saves the inference results in a folder named `inference_<datetime>` in the `result_path_folder`.
    """

    # Getting polar data from the scatterometer data path
    for file in os.listdir(scatterometer_data_path):
        if file.endswith(".nc"):
            polar_data = xr.open_dataset(
                os.path.join(scatterometer_data_path, file),
                engine="h5netcdf",
                drop_variables=["DQF"],
            )
            break
    else:
        print('WARNING : No .nc file found in the scatterometer data path, please check the path')

    if goes_channel == 'C01' or goes_channel == 'C03' or goes_channel == 'C05':
        for file in os.listdir('./satellite_indices/'):
                if file.endswith('.npy') and "1km" in file:
                    parallel_index = np.load(
                    os.path.join('./satellite_indices/', file), allow_pickle=True
            )
                    break
    elif goes_channel == 'C02':
        for file in os.listdir('./satellite_indices/'):
                if file.endswith('.npy') and "0.5km" in file:
                    parallel_index = np.load(
                    os.path.join('./satellite_indices/', file), allow_pickle=True
            )
                    break
    else :
        parallel_index = np.load('./satellite_indices/2km at nadir_index.npy', allow_pickle=True)

    datetime = pd.to_datetime(datetime)


    images, valid_lats, valid_lons, _ = extract_goes_production(datetime,polar_data,parallel_index,goes_channel,goes_aws_url_folder)
    lat_inference, lon_inference, wind_speeds_inference = inference_whole_image(result_path_folder,images,valid_lats,valid_lons,model_parameters,normalization_factors,
    )

    buoy_lat,buoy_lon,buoy_time,buoy_wind_speed,buoy_name = buoy_data_extract(buoy_path,polar_data,datetime)

    safe_file_name = datetime.strftime('%Y-%m-%d_%H-%M-%S')
    path_inference = f'{result_path_folder}/inference_{safe_file_name}/'
    if not os.path.exists(path_inference):
        os.makedirs(path_inference)

    print('INFO : Folder of inference created:', path_inference)

    goes_image = plot_goes_image(lat_inference,lon_inference,images,path_inference,buoy_name,buoy_lat,buoy_lon,datetime)
    plot_wind_speeds(lat_inference,lon_inference,wind_speeds_inference,path_inference,buoy_name,buoy_lat,buoy_lon,datetime)

    np.save(f'{path_inference}/data_goes_image.npy',goes_image)
    np.save(f'{path_inference}/data_wind_speeds.npy',wind_speeds_inference)
    np.save(f'{path_inference}/data_lat.npy',lat_inference)
    np.save(f'{path_inference}/data_lon.npy',lon_inference)

    reshaped_images = np.reshape(images, (160,340,128,128))

    for i in range(len(buoy_wind_speed)):
        try:
            lat_index = np.where(lat_inference == buoy_lat[i])[0][0]
            lon_index = np.where(lon_inference == buoy_lon[i])[1][0]

            buoy_image =  reshaped_images[lat_index,lon_index]
            threshold = np.max(buoy_image) * 0.6
            cloud_mask = np.where(buoy_image > threshold, 1,0)
            percentage_cloud = np.sum(cloud_mask)/np.size(cloud_mask)

            wind_speeds_inference_buoy = wind_speeds_inference[lat_index,lon_index]
            difference = wind_speeds_inference_buoy - buoy_wind_speed[i]
            print(f'EVAL : Buoy {buoy_name[i]} - Inference Wind Speed : {wind_speeds_inference_buoy} - Buoy Wind Speed : {buoy_wind_speed[i]} - Difference : {difference}, Percentage Cloud : {percentage_cloud}')

            update_buoy_comparison_csv(
                result_path_folder, datetime, buoy_name, i,
                wind_speeds_inference_buoy, buoy_wind_speed, difference, percentage_cloud
            )

        except:
            print(f'EVAL : Buoy {buoy_name[i]} not in the inference area')
            continue
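
A call sketch (all paths are placeholders, and model_parameters / normalization_factors are assumed to be the exact dictionaries used when training with train_test_model; see inference_run below for the expected keys):

inference_full_goes_image(
    datetime="2022-07-15 17:30:00",
    scatterometer_data_path="./scatterometer_data/",
    result_path_folder="./results/run_001",
    model_parameters=model_parameters,
    buoy_path="./buoy_data/",
    normalization_factors=normalization_factors,
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
    goes_channel="C01",
)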

inference_model(model, inference_loader, device)

Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

Parameters:

    model (Module): The trained model to be used for inference. Required.
    inference_loader (DataLoader): DataLoader for the inference dataset. Required.
    device (device): Device to run the model on (CPU or GPU). Required.

Returns:

    inference_outputs (ndarray): Outputs from the model on the inference dataset.

Source code in windscangeo\impl.py, lines 261-290
def inference_model(model, inference_loader, device):
    """
    Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

    Args:
        model (torch.nn.Module): The trained model to be used for inference.
        inference_loader (torch.utils.data.DataLoader): DataLoader for the inference dataset.
        device (torch.device): Device to run the model on (CPU or GPU).

    Returns:
        inference_outputs (numpy.ndarray): Outputs from the model on the inference dataset.
    """

    with torch.no_grad():  # Disable gradient calculation for inference

        inference_outputs = []

        for images in inference_loader:
            images = images.to(device)

            outputs = model(images).squeeze(-1)

            # Append outputs to the list
            inference_outputs.append(outputs)

        inference_outputs = torch.cat(inference_outputs, dim=0)
        inference_outputs = inference_outputs.cpu()
        inference_outputs = inference_outputs.numpy()

    return inference_outputs

inference_run(images, model_parameters, model_path, normalization_factors)

Runs inference on the provided images using the specified model parameters and normalization factors.

Parameters:

    images (ndarray): Array of images to be used for inference. Required.
    model_parameters (dict): Dictionary containing model parameters such as batch size, image size, channels, etc. Required.
    model_path (str): Path to the pre-trained model file. Required.
    normalization_factors (dict): Dictionary containing normalization factors such as mean and standard deviation. Required.

Returns:

    inference_output (ndarray): Array of inference outputs (wind speeds).

Source code in windscangeo\func_inference.py, lines 235-314
def inference_run(
    images, model_parameters, model_path, normalization_factors
):

    """
    Runs inference on the provided images using the specified model parameters and normalization factors.

    Args:
        images (np.ndarray): Array of images to be used for inference.
        model_parameters (dict): Dictionary containing model parameters such as batch size, image size, channels
        model_path (str): Path to the pre-trained model file.
        normalization_factors (dict): Dictionary containing normalization factors such as mean and standard deviation.

    Returns:
        inference_output (np.ndarray): Array of inference outputs (wind speeds).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = model_parameters["batch_size"]
    image_height = model_parameters["image_size"]
    image_width = model_parameters["image_size"]
    in_channels = model_parameters["image_channels"]
    dropout_rate = model_parameters["dropout_rate"]
    model_choice = model_parameters["model_choice"]

    mean = normalization_factors["mean"]
    std = normalization_factors["std"]
    if model_choice == 'CNN':

        features_cnn = model_parameters["features_cnn"]
        kernel_size = model_parameters["kernel_size"]
        activation_cnn = model_parameters["activation_cnn"]
        activation_final = model_parameters["activation_final"]
        stride = model_parameters["stride"]

        print("model choice is CNN")
        model = ConventionalCNN(
            image_height,
            image_width,
            features_cnn,
            kernel_size,
            in_channels,
            activation_cnn,
            activation_final,
            stride,
            dropout_rate
        ).to(device)

    if model_choice == 'ViT':
        print('model choice is ViT')
        # Load the model
        model = ViT(
            img_size = (128, 128),
            patch_size = (8,8),
            n_channels = 1,
            d_model = 1024,
            nhead = 4,
            dim_feedforward = 2048,
            blocks = 8,
            mlp_head_units = [1024, 512],
            n_classes = 1,
        ).to(device)

    if model_choice == 'ResNet':
        model = ResNet50(num_classes=1, channels=1).to(device)



    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    inference_dataset = conventional_dataset_inference(
        images,
        transform=Normalize(mean,std),
    )
    inference_loader = DataLoader(inference_dataset, batch_size, shuffle=False)

    inference_output = inference_model(model, inference_loader, device)

    return inference_output
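
A usage sketch for the CNN branch (every value in the dictionaries is illustrative and must match the configuration used during training; the image file name is hypothetical):

import numpy as np

model_parameters = {
    "batch_size": 64,
    "image_size": 128,
    "image_channels": 1,
    "dropout_rate": 0.1,
    "model_choice": "CNN",
    "features_cnn": [32, 64, 128],
    "kernel_size": 3,
    "activation_cnn": "relu",
    "activation_final": "linear",
    "stride": 1,
}
normalization_factors = {"mean": 0.25, "std": 0.12}

images = np.load("images_inference.npy")   # hypothetical array of shape (N, 1, 128, 128)
wind_speeds = inference_run(images, model_parameters, "best_model.pth", normalization_factors)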

manage_saved_models(directory)

Manage saved model files in the specified directory by deleting older epoch files. Keeps only the latest epoch file and deletes all others. From @ Jing Sun

Parameters:

    directory (str): The directory where model files are saved. Required.
Source code in windscangeo\impl.py, lines 36-61
def manage_saved_models(directory):  # From @Jing
    """
    Manage saved model files in the specified directory by deleting older epoch files.
    Keeps only the latest epoch file and deletes all others. From @ Jing Sun

    Args:
        directory (str): The directory where model files are saved.
    """

    pattern = re.compile(r"epoch_(\d+)\.pth")
    epoch_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            match = pattern.match(file)
            if match:
                epoch_num = int(match.group(1))
                file_path = os.path.join(root, file)
                epoch_files.append((file_path, epoch_num))

    # Check if there are more than 5 files
    if len(epoch_files) > 1:
        epoch_files.sort(key=lambda x: x[1])
        files_to_delete = len(epoch_files) - 1

        for i in range(files_to_delete):
            os.remove(epoch_files[i][0])

package_data(images, numerical_data, filter=True, solar_conversion=False, verbose=True)

This function packages the images and numerical data into a format that can be used for training a machine learning model. The function will filter out invalid images and fill in any NaN values. (Invalid images = empty images from GOES data) The function will also convert the observation times, latitudes and longitudes to solar angles (sza, saa) if solar_conversion is set to True. The function will return the images and numerical data in a numpy array format.

Parameters:

    images (ndarray): The GOES images corresponding to the observation data. Required.
    numerical_data (dict): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds". Required.
    filter (bool): If True, the function will filter out invalid images and fill in NaN values. Default: True.
    solar_conversion (bool): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). Not used in the current implementation, but kept in case of future use. Default: False.
    verbose (bool): If True, the function will print progress information. Default: True.

Returns:

    images (ndarray): The GOES images corresponding to the observation data.
    numerical_data (ndarray): The numerical data corresponding to the observation data (sza, saa, main_parameter if solar_conversion is True; lat, lon, time, wind_speeds if solar_conversion is False).

Source code in windscangeo\func.py, lines 876-935
def package_data(
    images,
    numerical_data,
    filter=True,
    solar_conversion=False,
    verbose=True
):
    """
    This function packages the images and numerical data into a format that can be used for training a machine learning model.
    The function will filter out invalid images and fill in any NaN values. (Invalid images = empty images from GOES data)
    The function will also convert the observation times, latitudes and longitudes to solar angles (sza, saa) if solar_conversion is set to True.
    The function will return the images and numerical data in a numpy array format.

    Args:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (dict): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds".
        filter (bool): If True, the function will filter out invalid images and fill in Nan values. Default is True.
        solar_conversion (bool): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). Default is False. (Not used in current implementation, but kept in case of future use)
        verbose (bool): If True, the function will print progress information. Default is True.

    Returns:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (numpy.ndarray): The numerical data corresponding to the observation data. (sza, saa, main_parameter if solar_conversion is set to True or lat, lon, time, wind_speeds if solar_conversion is set to False)

    """
    if filter:
        (images, numerical_data) = filter_invalid(images, numerical_data)
        images = fill_nans(images)

    if solar_conversion:
        observation_lats = numerical_data["observation_lats"]
        observation_lons = numerical_data["observation_lons"]
        observation_times = numerical_data["observation_times"]

        sza, saa = vectorized_solar_angles(
            observation_lats, observation_lons, observation_times
        )

        sza_rad = np.deg2rad(sza)
        sza_sin = np.sin(sza_rad)
        sza_cos = np.cos(sza_rad)

        saa_rad = np.deg2rad(saa)
        saa_sin = np.sin(saa_rad)
        saa_cos = np.cos(saa_rad)

        # Add the solar angles to the numerical data dictionary

        numerical_data["sza_sin"] = sza_sin
        numerical_data["sza_cos"] = sza_cos
        numerical_data["saa_sin"] = saa_sin
        numerical_data["saa_cos"] = saa_cos

        print("Data Preparation : converted to solar angles (sza, saa)")
        print("Data Preparation : returning images, numerical_data")
        return images, numerical_data

    else:

        return images, numerical_data
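
A minimal sketch of the default (non-solar) path with synthetic data:

import numpy as np

images = np.random.rand(10, 1, 128, 128)
numerical_data = {
    "observation_lats": np.random.uniform(10, 30, 10),
    "observation_lons": np.random.uniform(-60, -20, 10),
    "observation_times": np.full(10, np.datetime64("2023-04-01T15:00")),
    "wind_speeds": np.random.uniform(0, 20, 10),
}

images, numerical_data = package_data(images, numerical_data, filter=True, solar_conversion=False)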

patchify(batch, patch_size)

Patchify the batch of images

Shape:

    batch: (b, h, w, c)
    output: (b, nh, nw, ph, pw, c)

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py, lines 187-204
def patchify(batch, patch_size):
    """
    Patchify the batch of images

    Shape:
        batch: (b, h, w, c)
        output: (b, nh, nw, ph, pw, c)

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    b, c, h, w = batch.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw

    batch_patches = torch.reshape(batch, (b, c, nh, ph, nw, pw))
    batch_patches = torch.permute(batch_patches, (0, 1, 2, 4, 3, 5))

    return batch_patches
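
A quick shape check (note that, as written, the function unpacks channels-first input (b, c, h, w) and returns patches of shape (b, c, nh, nw, ph, pw)):

import torch

batch = torch.randn(2, 1, 128, 128)   # (b, c, h, w)
patches = patchify(batch, (8, 8))
print(patches.shape)                   # torch.Size([2, 1, 16, 16, 8, 8])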

plot_cloud_cover(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted. Can be adapted to work with larger images.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. Required.
    lon_inference (ndarray): Longitude values for the inference grid. Required.
    images (ndarray): GOES image data to be used for cloud cover calculation. Required.
    path_folder (str): Path to save the plot. Required.
    buoy_name (list or ndarray): Names of the buoys. Required.
    buoy_lat (list or ndarray): Latitude values of the buoys. Required.
    buoy_lon (list or ndarray): Longitude values of the buoys. Required.
    time_choice (datetime): Time of the inference. Required.
Source code in windscangeo\func_ml.py, lines 458-482
def plot_cloud_cover(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images but larger images can diluted. Can be adapted to work with larger images.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be used for cloud cover calculation.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = np.mean(images, axis=(2,3))
    threshold = 0.11
    cloud_mask = np.where(mean_images > threshold, 1, 0)
    cloud_mask = cloud_mask.reshape(160,340)

    plot_cloud_mask(lat_inference,lon_inference,cloud_mask,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice)

    percentage_cloud = np.sum(cloud_mask)/cloud_mask.size
    print('Cloud coverage : ',percentage_cloud*100,'%')

    return cloud_mask, percentage_cloud

plot_cloud_mask(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted. Can be adapted to work with larger images.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. Required.
    lon_inference (ndarray): Longitude values for the inference grid. Required.
    wind_speeds_inference (ndarray): Cloud mask data to be plotted. Required.
    path_folder (str): Path to save the plot. Required.
    buoy_name (list or ndarray): Names of the buoys. Required.
    buoy_lat (list or ndarray): Latitude values of the buoys. Required.
    buoy_lon (list or ndarray): Longitude values of the buoys. Required.
    time_choice (datetime): Time of the inference. Required.
Source code in windscangeo\func_ml.py, lines 305-388
def plot_cloud_mask(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud mask on a map with buoy locations. This works well with 30x30 images but larger images can diluted. Can be adapted to work with larger images.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Cloud mask data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """

    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, wind_speeds_inference,
        shading='auto', cmap='Blues',
        vmin=0, vmax=1
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='0 = Clear, 1 = Cloudy')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)

    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))

    # Plot each buoy in a single color
    for i, buoy_id in enumerate(unique_buoys):
        # Pick a distinct color from tab20
        color = color_map(i)

        # Identify the indices belonging to this buoy
        mask = np.array(buoy_name_flat) == buoy_id

        # Scatter just those points
        ax.scatter(
            np.array(buoy_lon)[mask],
            np.array(buoy_lat)[mask],
            s=100,
            color=color,            # Set the fill color
            edgecolor='black',
            linewidth=1,
            zorder=12,
            label=buoy_id           # Use buoy_id as the legend label
        )

    # Create the legend
    leg = ax.legend(
        title="Buoy Stations",
        loc="upper right",
        bbox_to_anchor=(1.0, 1.0),
        bbox_transform=ax.transAxes
    )

    # Ensure legend is above all other layers
    leg.set_zorder(999)

    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'Cloud Mask at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())


    plt.savefig(f'{path_folder}/plot_cloud_mask.png')

plot_goes_image(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the GOES image data on a map with buoy locations. Made for 128x128 images where the midpoint is at (64,64). If using other image sizes, the plotting will probably not work as expected.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. Required.
    lon_inference (ndarray): Longitude values for the inference grid. Required.
    images (ndarray): GOES image data to be plotted. Required.
    path_folder (str): Path to save the plot. Required.
    buoy_name (list or ndarray): Names of the buoys. Required.
    buoy_lat (list or ndarray): Latitude values of the buoys. Required.
    buoy_lon (list or ndarray): Longitude values of the buoys. Required.
    time_choice (datetime): Time of the inference. Required.
Source code in windscangeo\func_ml.py, lines 392-455
def plot_goes_image(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the GOES image data on a map with buoy locations. Made for 128x128 images where the middpoint is at (64,64). If using other image sizes, the plotting will probably not work as expected.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = images[:,:,64,64]
    mean_images = mean_images.ravel()
    mean_images = mean_images.reshape(160,340)


    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())


    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, mean_images,
        shading='auto',
        vmin=0,
        vmax=1,
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='Brightness Temperature (K) - C01')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)


    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))


    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'GOES image at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())

    plt.savefig(f'{path_folder}/plot_goes_image.png')

    mean_images = np.array(mean_images)

    return mean_images

plot_save_loss(best_val_outputs, best_val_labels, train_losses, val_losses, path_folder, saving=False)

Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

Parameters:

    best_val_outputs (list or ndarray): Model outputs for the validation dataset. [required]
    best_val_labels (list or ndarray): True labels for the validation dataset. [required]
    train_losses (list): List of training losses per epoch. [required]
    val_losses (list): List of validation losses per epoch. [required]
    path_folder (str): Path to save the plot and optionally the outputs and labels. [required]
    saving (bool): If True, save the best validation outputs and labels. [default: False]
Source code in windscangeo\func_ml.py
def plot_save_loss(
    best_val_outputs,
    best_val_labels,
    train_losses,
    val_losses,
    path_folder,
    saving=False,
):
    """
    Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        train_losses (list): List of training losses per epoch.
        val_losses (list): List of validation losses per epoch.
        path_folder (str): Path to save the plot and optionally the outputs and labels.
        saving (bool, optional): If True, save the best validation outputs and labels. Default is False.
    """
    # After training, save only the best validation outputs and labels
    if saving:
        np.save(
            os.path.join(path_folder, "best_validation_outputs.npy"), best_val_outputs
        )
        np.save(
            os.path.join(path_folder, "best_validation_labels.npy"), best_val_labels
        )

    num_epochs = len(train_losses)
    # Plotting the training and validation losses
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, num_epochs + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    text_str = f"num_epochs = {num_epochs}, train loss = {train_losses[-1]:.2f}, validation loss = {val_losses[-1]:.2f}"
    plt.text(
        0.05,
        0.05,
        text_str,
        ha="left",
        va="bottom",
        transform=plt.gca().transAxes,  # Ensures the coordinates are relative to the axes (0 to 1 range)
    )
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(path_folder, "loss_plot.png"))

plot_wind_speeds(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the wind speeds on a map with buoy locations. Nighttime areas are masked out, and coastlines, gridlines, and buoy locations are added.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. [required]
    lon_inference (ndarray): Longitude values for the inference grid. [required]
    wind_speeds_inference (ndarray): Wind speed data to be plotted. [required]
    path_folder (str): Path to save the plot. [required]
    buoy_name (list or ndarray): Names of the buoys. [required]
    buoy_lat (list or ndarray): Latitude values of the buoys. [required]
    buoy_lon (list or ndarray): Longitude values of the buoys. [required]
    time_choice (datetime): Time of the inference. [required]
Source code in windscangeo\func_ml.py
def plot_wind_speeds(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the wind speeds on a map with buoy locations. Filter nighttime images and add coastlines, gridlines, and buoy locations.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Wind speed data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    time_str = time_choice.strftime('%Y-%m-%d %H:%M:%S')
    date, time = time_str.split(' ')
    lat = lat_inference
    lon = lon_inference
    wind_speeds = wind_speeds_inference


    buoy_names = buoy_name
    buoy_lat = buoy_lat
    buoy_lon = buoy_lon


    # nighttime mask

    lat_flat = lat.flatten()
    lon_flat = lon.flatten()
    time_flat = np.full(len(lat_flat), pd.Timestamp(f'{date} {time}'), dtype='datetime64[ns]')


    sza, saa = vectorized_solar_angles(lat_flat, lon_flat, time_flat)
    saa = np.reshape(saa,lat.shape)
    sza = np.reshape(sza,lat.shape)

    night_time_mask = np.where(sza > 90, 1, np.nan)
    cmap = ListedColormap(['white'])

    ###############

    min_lon, max_lon, min_lat, max_lat = -70, 0, -12, 20

    # add coastlines and gridlines
    fig = plt.figure(figsize=(22, 10), dpi=100)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    levels = np.arange(0, 13, 1)
    line_colour = 'black'
    line_colours = ['black' for i in levels]
    ax.title.set_text(f'Wind Speed prediction from C01 GOES image (m/s) {date} {time} ')
    ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15,zorder = 5)
    ax.contourf(lon, lat, night_time_mask, transform=ccrs.PlateCarree(),cmap=cmap,zorder = 10)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.LAND, edgecolor='black',zorder= 20)
    ax.set_xticks(np.arange(-70, 1, 10), crs=ccrs.PlateCarree())
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0f}°E'))
    ax.set_yticks(np.arange(-15, 21, 5), crs=ccrs.PlateCarree())
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0f}°N'))
    ax.set_extent((-70, 0, -12, 20))
    ax.hlines(0, -70, 0, color='red', linewidth=1.5, linestyle='--', zorder= 10)
    fig.colorbar(ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15), ax=ax, orientation='vertical', aspect=50, label='Wind Speed (m/s)')
    ax.gridlines(color='gray', linestyle='--', alpha=0.5,zorder= 999)

    for buoy in range(len(buoy_names)):
        lon_b, lat_b = buoy_lon[buoy], buoy_lat[buoy]

        # Skip if out of bounds
        if not (min_lon <= lon_b <= max_lon and min_lat <= lat_b <= max_lat):
            continue

        ax.plot(lon_b, lat_b, 'o', color='red', markersize=5, transform=ccrs.PlateCarree(), label=buoy_names[buoy],zorder = 999)

        ax.text(
            lon_b + 0.2, lat_b + 0.2, buoy_names[buoy],
            fontsize=9, color='white',
            transform=ccrs.PlateCarree(),
            bbox=dict(facecolor='black', edgecolor='none', boxstyle='square,pad=0.2'),zorder = 999
        )
    plt.savefig(f'{path_folder}/plot_wind_speeds.png')

rmse_per_range(model_output, target, path_folder)

Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

Parameters:

    model_output (list or ndarray): Model outputs for the validation dataset. [required]
    target (list or ndarray): True labels for the validation dataset. [required]
    path_folder (str): Path to save the CSV file. [required]

Returns:

    pd.DataFrame: DataFrame containing the RMSE and count for each range.

Source code in windscangeo\func_ml.py
def rmse_per_range(model_output, target,path_folder):
    """
    Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

    Args:
        model_output (list or np.ndarray): Model outputs for the validation dataset.
        target (list or np.ndarray): True labels for the validation dataset.
        path_folder (str): Path to save the CSV file.
    Returns:
        pd.DataFrame: DataFrame containing the RMSE and count for each range.
    """

    max_target = np.max(target)
    bins = np.arange(0, max_target, 1)
    rmse = np.zeros(len(bins))
    count = np.zeros(len(bins))
    results = []

    for i in range(len(bins)-1):
        idx = np.where((target >= bins[i]) & (target <= bins[i+1]))
        rmse[i] = np.sqrt(np.mean((model_output[idx] - target[idx])**2))
        count[i] = len(idx[0])
        print(f"EVAL : Range {bins[i]} m/s - {bins[i+1]} m/s: RMSE = {rmse[i]}, count = {int(count[i])}")
        results.append({'bin_start': bins[i], 'bin_end': bins[i+1], 'rmse': rmse[i], 'count': count[i]})

    df = pd.DataFrame(results)
    df.to_csv(f'{path_folder}/rmse_per_range.csv')
    return df
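
A small illustrative sketch with synthetic predictions (values are placeholders; the output folder is assumed to exist):

import numpy as np

target = np.array([0.5, 1.2, 2.7, 3.1, 4.8, 5.5])
model_output = target + np.array([0.1, -0.2, 0.3, -0.1, 0.4, -0.3])

# Prints the RMSE per 1 m/s bin and writes rmse_per_range.csv
df = rmse_per_range(model_output, target, "./results_folder")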

save_overpass_time(time_list, name_scatter)

This function prints the overpass time of the scatterometer.

Parameters:

    time_list (ndarray): The measurement time values of the scatterometer data. [required]
    name_scatter (str): The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc). [required]

Returns:

    None

Source code in windscangeo\func.py
def save_overpass_time(time_list,name_scatter):
    """
    This function prints the overpass time of the scatterometer.

    Args:
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        name_scatter (str): The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc).

    Returns:
        None 

    """
    formated_time = time_list.astype('datetime64[ns]')
    hour_minute = formated_time.astype('datetime64[m]')
    unique_hour_minute = np.unique(hour_minute)

    filtered = [unique_hour_minute[0]]

    delta = np.timedelta64(1, 'h')

    for time in unique_hour_minute[1:]:
        if time - filtered[-1] >= delta:
            filtered.append(time)

    time_only = []
    for time in filtered:
        time = str(time).split('T')[1]
        time_only.append(time)
    print(f"ORBIT : {name_scatter} overpass time : {time_only}")

savedataseperated(ScatterData, main_parameter, verbose=True)

This function extracts the valid lon / lat / measurement time and the main parameter from every pixel of the scatterometer data and saves it to a numpy file.

Parameters:

    ScatterData (Dataset): The ASCAT dataset containing the scatterometer data. [required]
    main_parameter (DataArray): The main parameter to be saved. This can be a classification / wind speed / wind direction etc. [required]

Returns:

    lat_list (ndarray): The latitude values of the scatterometer data.
    lon_list (ndarray): The longitude values of the scatterometer data.
    time_list (ndarray): The measurement time values of the scatterometer data.
    main_parameter_list (ndarray): The main parameter values of the scatterometer data.

This function saves the data locally to a folder called data_processed_scat.

Source code in windscangeo\func.py
def savedataseperated(ScatterData, main_parameter,verbose=True):
    """
    This function extracts the valid lon / lat / measurement time and the main parameter from every pixel
    of the scatterometer data and saves it to a numpy file.

    Args:
        ScatterData (xarray.Dataset): The ASCAT dataset containing the scatterometer data.
        main_parameter (xarray.DataArray): The main parameter to be saved. This can be a classification / wind speed / wind direction etc.

    Returns:

        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        main_parameter_list (numpy.ndarray): The main parameter values of the scatterometer data.

    this function saves the data locally to a folder called data_processed_scat
    """
    lat_full, lon_full, time_full = ScatterData.indexes.values()
    measurement_time_full = ScatterData.measurement_time

    lat_full = np.array(lat_full)
    lon_full = np.array(lon_full)
    measurement_time_full = np.array(measurement_time_full)
    main_parameter = np.array(main_parameter)

    index = np.argwhere(~np.isnan(main_parameter))

    index_list = []
    lat_list = []
    lon_list = []
    time_list = []
    wind_speed_list = []

    name_scatter = ScatterData.source

    for t, i, j in index:

        # print(t,'= time', i,'=row', j, '=column')
        index_list.append((t, i, j))

        # print(measurement_time_full[t, i, j].astype('datetime64[ns]'))
        time_list.append(measurement_time_full[t, i, j])

        # print(lat_full[i])
        lat_list.append(lat_full[i])

        # print(lon_full[j])
        lon_list.append(lon_full[j])

        # print(AllWindSpeeds[t, i, j])
        wind_speed_list.append(main_parameter[t, i, j])

    lat_list = np.array(lat_list)
    lon_list = np.array(lon_list)
    time_list = np.array(time_list)
    wind_speed_list = np.array(wind_speed_list)

    lat_list, lon_list, time_list, wind_speed_list = sort_by_time(
        lat_list, lon_list, time_list, wind_speed_list
    )
    if verbose:
        save_overpass_time(time_list,name_scatter)    

    return lat_list, lon_list, time_list, wind_speed_list

snap_to_nearest(values, reference_array, cutoff=1.0)

Snap an array of values to the nearest values in a reference array. If the difference is greater than the cutoff, the original value is returned.

Parameters:

    values (ndarray): Array of values to snap. [required]
    reference_array (ndarray): Array of reference values. [required]
    cutoff (float): Maximum allowable difference for snapping. [default: 1.0]

Returns:

    np.ndarray: Snapped values.

Source code in windscangeo\func_inference.py
def snap_to_nearest(values, reference_array, cutoff=1.0):
    """
    Snap an array of values to the nearest values in a reference array.
    If the difference is greater than the cutoff, the original value is returned.

    Args:
        values (np.ndarray): Array of values to snap.
        reference_array (np.ndarray): Array of reference values.
        cutoff (float): Maximum allowable difference for snapping.

    Returns:
        np.ndarray: Snapped values.
    """
    # Convert inputs to NumPy arrays for compatibility
    values = np.asarray(values)
    reference_array = np.asarray(reference_array)

    # Find the nearest reference value for each input value
    # Reshape reference_array to allow broadcasting
    reference_array = reference_array.reshape(1, -1)
    differences = np.abs(values.reshape(-1, 1) - reference_array)
    nearest_indices = np.argmin(differences, axis=1)
    nearest_values = reference_array.ravel()[nearest_indices]

    # Apply the cutoff condition
    snap_mask = np.abs(values - nearest_values) <= cutoff
    snapped_values = np.where(snap_mask, nearest_values, values)

    return snapped_values
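
A small worked example of the snapping and cutoff behaviour:

import numpy as np

values = np.array([10.2, 14.9, 20.4])
reference = np.array([10.0, 12.5, 20.0])

snap_to_nearest(values, reference, cutoff=1.0)
# -> array([10. , 14.9, 20. ]) : 10.2 and 20.4 snap to their nearest reference values,
#    while 14.9 is more than 1.0 away from every reference value and is left unchanged.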

sort_by_time(lat_list, lon_list, time_list, wind_speed_list)

This function sorts the output of savedataseperated() by time. Sorting makes the data processing more efficient and allows file caching for observations that are covered by the same GOES file.

Parameters:

    lat_list (ndarray): The latitude values of the scatterometer data. [required]
    lon_list (ndarray): The longitude values of the scatterometer data. [required]
    time_list (ndarray): The measurement time values of the scatterometer data. [required]
    wind_speed_list (ndarray): The wind speed values of the scatterometer data. [required]

Returns:

    lat_list_sorted (ndarray): The sorted latitude values of the scatterometer data.
    lon_list_sorted (ndarray): The sorted longitude values of the scatterometer data.
    time_list_sorted (ndarray): The sorted measurement time values of the scatterometer data.
    wind_speed_list_sorted (ndarray): The sorted wind speed values of the scatterometer data.

Source code in windscangeo\func.py
def sort_by_time(lat_list, lon_list, time_list, wind_speed_list):
    """
    This function sorts the output of savedataseperated() by time.
    This allows for more efficient data processing and allows file caching for times that are represented by the same GOES file.

    Args:
        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        wind_speed_list (numpy.ndarray): The wind speed values of the scatterometer data.

    Returns:
        lat_list_sorted (numpy.ndarray): The sorted latitude values of the scatterometer data.
        lon_list_sorted (numpy.ndarray): The sorted longitude values of the scatterometer data.
        time_list_sorted (numpy.ndarray): The sorted measurement time values of the scatterometer data.
        wind_speed_list_sorted (numpy.ndarray): The sorted wind speed values of the scatterometer data.

    """
    # Get the indices that would sort the measurement_time array
    sorted_indices = np.argsort(time_list)

    # Reorder the arrays using the sorted indices
    time_list_sorted = time_list[sorted_indices]
    lat_list_sorted = lat_list[sorted_indices]
    lon_list_sorted = lon_list[sorted_indices]
    speed_list_sorted = wind_speed_list[sorted_indices]

    return lat_list_sorted, lon_list_sorted, time_list_sorted, speed_list_sorted
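
A minimal sketch with two unsorted observations (values are placeholders):

import numpy as np

time_list = np.array(["2022-06-01T15:00", "2022-06-01T14:00"], dtype="datetime64[ns]")
lat_list = np.array([10.0, 11.0])
lon_list = np.array([-50.0, -51.0])
wind_speed_list = np.array([6.0, 7.5])

# Returns the same four arrays reordered so the 14:00 observation comes first
sort_by_time(lat_list, lon_list, time_list, wind_speed_list)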

test_model(model, test_loader, criterion, device)

Evaluate the model on the test dataset and return the outputs, targets, and average loss.

Parameters:

    model (Module): The trained model to be evaluated. [required]
    test_loader (DataLoader): DataLoader for the test dataset. [required]
    criterion (Module): Loss function to be used for evaluation. [required]
    device (device): Device to run the model on (CPU or GPU). [required]

Returns:

    test_outputs (ndarray): Outputs from the model on the test dataset.
    test_targets (ndarray): Targets corresponding to the test outputs.
    avg_test_loss (float): Average loss on the test dataset.

Source code in windscangeo\impl.py
def test_model(model, test_loader, criterion, device):
    """
    Evaluate the model on the test dataset and return the outputs, targets, and average loss.

    Args:
        model (torch.nn.Module): The trained model to be evaluated.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        criterion (torch.nn.Module): Loss function to be used for evaluation.
        device (torch.device): Device to run the model on (CPU or GPU).

    Returns:
        test_outputs (numpy.ndarray): Outputs from the model on the test dataset.
        test_targets (numpy.ndarray): Targets corresponding to the test outputs.
        avg_test_loss (float): Average loss on the test dataset.
    """

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for inference

        test_outputs = []
        test_targets = []
        test_loss = 0.0

        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images).squeeze(-1)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            # Append outputs to the list
            test_outputs.append(outputs)
            test_targets.append(targets)

        avg_test_loss = test_loss / len(test_loader)
        print(f"EVAL : Test Loss: {avg_test_loss}")

        test_outputs = torch.cat(test_outputs, dim=0)
        test_outputs = test_outputs.cpu()
        test_outputs = test_outputs.numpy()

        test_targets = torch.cat(test_targets, dim=0)
        test_targets = test_targets.cpu()
        test_targets = test_targets.numpy()

    return test_outputs, test_targets, avg_test_loss
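
An evaluation sketch using a tiny stand-in model and random data (for illustration only; a real run would use the trained WindScanGEO model and the DataLoaders built in train_test_model):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in regression model and random 128x128 single-channel images
model = nn.Sequential(nn.Flatten(), nn.Linear(128 * 128, 1)).to(device)
test_loader = DataLoader(TensorDataset(torch.rand(8, 1, 128, 128), torch.rand(8)), batch_size=4)

test_outputs, test_targets, avg_test_loss = test_model(model, test_loader, nn.MSELoss(), device)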

train_model(model, train_loader, val_loader, num_epochs, lr, weight_decay, criterion, device, optimizer_choice, patience_epochs, patience_loss, path_folder)

Train the model with the given parameters dictionary and save the best validation outputs, labels, and model.

Parameters:

    model (Module): The model to be trained. [required]
    train_loader (DataLoader): DataLoader for the training dataset. [required]
    val_loader (DataLoader): DataLoader for the validation dataset. [required]
    num_epochs (int): Number of epochs to train the model. [required]
    lr (float): Learning rate for the optimizer. [required]
    weight_decay (float): Weight decay for the optimizer. [required]
    criterion (Module): Loss function to be used. [required]
    device (device): Device to run the model on (CPU or GPU). [required]
    optimizer_choice (str): Choice of optimizer ('Adam', 'SGD', 'RMSprop'). [required]
    patience_epochs (int): Number of epochs to wait before stopping if no improvement in validation loss. [required]
    patience_loss (float): Minimum change in validation loss to consider as an improvement. [required]
    path_folder (str): Path to save the model checkpoints. [required]

Returns:

    best_val_outputs (ndarray): Best validation outputs from the model.
    best_val_labels (ndarray): Best validation labels corresponding to the outputs.
    best_model (Module): The best model based on validation loss.
    train_losses (list): List of training losses for each epoch.
    val_losses (list): List of validation losses for each epoch.

Source code in windscangeo\impl.py
def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs,
    lr,
    weight_decay,
    criterion,
    device,
    optimizer_choice,
    patience_epochs,
    patience_loss,
    path_folder,
):

    """
    Train the model with the given parameters dictionary and save the best validation outputs, labels, and model.

    Args:
        model (torch.nn.Module): The model to be trained.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        num_epochs (int): Number of epochs to train the model.
        lr (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        criterion (torch.nn.Module): Loss function to be used.
        device (torch.device): Device to run the model on (CPU or GPU).
        optimizer_choice (str): Choice of optimizer ('Adam', 'SGD', 'RMSprop').
        patience_epochs (int): Number of epochs to wait before stopping if no improvement in validation loss.
        patience_loss (float): Minimum change in validation loss to consider as an improvement.
        path_folder (str): Path to save the model checkpoints.

    Returns:
        best_val_outputs (numpy.ndarray): Best validation outputs from the model.
        best_val_labels (numpy.ndarray): Best validation labels corresponding to the outputs.
        best_model (torch.nn.Module): The best model based on validation loss.
        train_losses (list): List of training losses for each epoch.
        val_losses (list): List of validation losses for each epoch.
    """


    if optimizer_choice == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_choice == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_choice == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError("Invalid optimizer choice. Please choose 'Adam' or 'SGD'.")

    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.8)

    # Initialize lists to store loss values and validation predictions
    train_losses = []
    val_losses = []
    best_val_loss = float("inf")
    best_val_outputs = None
    best_val_labels = None


    pbar = tqdm(range(num_epochs), desc="TRAIN : Training Progress")    
    for epoch in pbar:
        # Training Phase
        model.train()
        running_loss = 0.0

        for images, targets in train_loader:

            # Move data to GPU
            images = images.to(device)
            targets = targets.to(device)            

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images).squeeze(-1)
            loss = criterion(outputs, targets)  # Ensure target shape matches output
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Calculate average training loss for the epoch
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        #print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")
        pbar.set_postfix({"Train Loss": f"{avg_train_loss:.4f}"})

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_outputs = []  # Temporary list to store outputs for this epoch
        val_labels = []  # Temporary list to store labels for this epoch
        with torch.no_grad():
            for images, targets in val_loader:
                # Move data to GPU
                images = images.to(device)
                targets = targets.to(device)

                # Get model output
                outputs = model(images).squeeze(-1)

                # Calculate loss
                loss = criterion(outputs, targets)  # Ensure target shape matches output
                val_loss += loss.item()

                # Append outputs and targets to lists
                val_outputs.append(outputs.cpu())  # Move to CPU for concatenation
                val_labels.append(targets.cpu())

        # Concatenate outputs and labels across all batches to ensure all samples are included
        val_outputs = torch.cat(val_outputs, dim=0).numpy()
        val_labels = torch.cat(val_labels, dim=0).numpy()

        # Calculate average validation loss for the epoch
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        #print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss}")
        pbar.set_postfix({"Train Loss": f"{avg_train_loss:.4f}", "Val Loss": f"{avg_val_loss:.4f}"})

        # Check if this is the best validation loss so far and store best outputs/labels
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_outputs = val_outputs
            best_val_labels = val_labels

            best_model = model
            best_model_state = model.state_dict()
            torch.save(
                best_model_state, os.path.join(path_folder, f"./epoch_{epoch + 1}.pth")
            )


        manage_saved_models(path_folder)

        if early_stopping(val_losses, patience_epochs, patience_loss):
            return (
                best_val_outputs,
                best_val_labels,
                best_model,
                train_losses,
                val_losses,
            )
        scheduler.step()



    return best_val_outputs, best_val_labels, model, train_losses, val_losses
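
A minimal training sketch with a stand-in model and random data (placeholders only; path_folder must already exist so that checkpoints can be written):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(128 * 128, 1)).to(device)

train_loader = DataLoader(TensorDataset(torch.rand(32, 1, 128, 128), torch.rand(32)), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(torch.rand(8, 1, 128, 128), torch.rand(8)), batch_size=8)

best_val_outputs, best_val_labels, best_model, train_losses, val_losses = train_model(
    model, train_loader, val_loader,
    num_epochs=2, lr=1e-3, weight_decay=0.0,
    criterion=nn.MSELoss(), device=device, optimizer_choice="Adam",
    patience_epochs=5, patience_loss=0.001, path_folder="./results_folder",
)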

train_test_model(saved_file_path, run_name, model_parameters, normalization_factors)

Trains and tests a model using the provided parameters and data from a saved file (from extract_matching_orbits).

Parameters:

    saved_file_path (str): Path to the saved .npz file containing preloaded data. [required]
    run_name (str): Name of the run, used to create a folder for saving results. [required]
    model_parameters (dict): Dictionary containing model parameters such as batch size, image size, learning rate, etc. (Supports CNN, ResNet, ViT.) See tutorial for details. [required]
    normalization_factors (dict): Dictionary containing normalization factors (mean and std) for the images dataset. [required]

Returns:

    result_path_folder (str): Path to the folder where results are saved.

model_parameters should contain the following (dictionary, values can be changed as needed):
    "batch_size": 256,
    "image_size": 128,
    "image_channels": 1,
    "model_choice": "ResNet",  # or "CNN" or "ViT"
    "criterion": nn.MSELoss(),  # or any other PyTorch loss function
    "optimizer_choice": "Adam",
    "learning_rate": 0.003305753102490767,
    "weight_decay": 0.00000148842072509874,
    "dropout_rate": 0.2752124679248082,
    "num_epochs": 10,
    "patience_epochs": 20,  # early stopping
    "patience_loss": 0.001,

    # The following additional parameters are required with the CNN:
    "activation_cnn": nn.ReLU(),
    "activation_final": nn.Identity(),
    "kernel_size": 3,
    "features_cnn": [64, 64, 64, 64],
    "stride": 1,

normalization_factors should contain the following (dictionary, values can be changed as needed):
    "mean": 0.0,  # mean of the images dataset
    "std": 1.0,   # std of the images dataset

Source code in windscangeo\main_func.py
def train_test_model(
    saved_file_path: str,
    run_name: str,
    model_parameters: dict,
    normalization_factors: dict,
    ):

    """
    Trains and tests a model using the provided parameters and data from a saved file (from `extract_matching_orbits`).

    Args:
        saved_file_path (str): Path to the saved .npz file containing preloaded data.
        run_name (str): Name of the run, used to create a folder for saving results.
        model_parameters (dict): Dictionary containing model parameters such as batch size, image size, learning rate, etc (Supports CNN, ResNet, ViT). See tutorial for details.
        normalization_factors (dict): Dictionary containing normalization factors (mean and std) for the images dataset.

    Returns:
        result_path_folder (str): Path to the folder where results are saved.  


    model_parameters should contain the following (dictionary, values can be changed as needed):
            "batch_size" : 256,
            "image_size": 128, 
            "image_channels" : 1,  
            "model_choice" : "ResNet", # or "CNN" or"ViT"
            "criterion" : nn.MSELoss(), # or any other PyTorch loss function
            "optimizer_choice" : "Adam", 
            "learning_rate" : 0.003305753102490767,
            "weight_decay" : 0.00000148842072509874,
            "dropout_rate" : 0.2752124679248082,
            "num_epochs" : 10, 
            "patience_epochs" : 20, # early stopping
            "patience_loss" : 0.001,

            # The following additional parameters are required with the CNN :
            "activation_cnn" : nn.ReLU(),
            "activation_final" : nn.Identity(),
            "kernel_size" : 3,
            "features_cnn" : [64,64,64,64],
            "stride" : 1,

    normalization_factors should contain the following (dictionary, values can be changed as needed):
            "mean" : 0.0, # mean of the images dataset
            "std" : 1.0, # std of the images dataset

    """

    # Ignore warnings for division by zero and invalid operations
    np.seterr(divide='ignore', invalid='ignore')

    # create a folder for the experiment
    result_path_folder = create_folder(run_name)

    # load the model parameters
    batch_size = model_parameters["batch_size"]
    image_height = model_parameters["image_size"]
    image_width = model_parameters["image_size"]
    in_channels = model_parameters["image_channels"]
    lr = model_parameters["learning_rate"]
    weight_decay = model_parameters["weight_decay"]
    criterion = model_parameters["criterion"]
    optimizer_choice = model_parameters["optimizer_choice"]
    dropout_rate = model_parameters["dropout_rate"]


    num_epochs = model_parameters["num_epochs"]
    patience_epochs = model_parameters["patience_epochs"]
    patience_loss = model_parameters["patience_loss"]
    model_choice = model_parameters["model_choice"]

    # specify the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"INFO : Pytorch is using device: {device}")

    # Load the model

    if model_choice == "CNN":

        activation_cnn = model_parameters["activation_cnn"]
        activation_final = model_parameters["activation_final"]
        kernel_size = model_parameters["kernel_size"]
        features_cnn = model_parameters["features_cnn"]
        stride = model_parameters["stride"]

        model = ConventionalCNN(
            image_height,
            image_width,
            features_cnn,
            kernel_size,
            in_channels,
            activation_cnn,
            activation_final,
            stride,
            dropout_rate
        ).to(device)
        print("INFO : model choice is CNN")

    if model_choice == 'ViT':
        model = ViT(
            img_size = (128, 128),
            patch_size = (16,16),
            n_channels = 1,
            d_model = 1024,
            nhead = 4,
            dim_feedforward = 2048,
            blocks = 8,
            mlp_head_units = [1024, 512],
            n_classes = 1,
        ).to(device)
        print('INFO : model choice is ViT')

    if model_choice == 'ResNet':
        print('INFO : model choice is ResNet')
        model = ResNet50(num_classes=1, channels=1).to(device)

        # use the best saved model for test
        # print('using preloaded model, continuing training')
        # model_file = [file for file in os.listdir(path_folder) if file.endswith(".pth")][0]
        # model_path = os.path.join(path_folder, model_file)
        # model.load_state_dict(torch.load(model_path,map_location=torch.device(device)) )


    # loading normalization factors
    mean = normalization_factors["mean"]
    std = normalization_factors["std"]

    data_file = np.load(saved_file_path,allow_pickle=True)
    data_file_images = np.array(data_file['images'])
    data_file_numerical_data = data_file['numerical_data'].item()['observation_wind_speeds']

    print("INFO : Data loaded from file:", saved_file_path)

    train_images, rest_images, train_targets, rest_targets = sklearn.model_selection.train_test_split(data_file_images,data_file_numerical_data,train_size = 0.8,random_state=42)
    val_images, test_images, val_targets, test_targets = sklearn.model_selection.train_test_split(rest_images,rest_targets,train_size = 0.5,random_state=42)
    print("INFO : Data split into train (0.8), validation (0.1) and test sets (0.1)")
    # loading the data 


    train_dataset = conventional_dataset(
        train_images,
        train_targets,
        #transform=Normalize(mean,std),
    )

    validation_dataset = conventional_dataset(
        val_images,
        val_targets,
        #transform=Normalize(mean,std),
    )

    test_dataset = conventional_dataset(
        test_images,
        test_targets,
        #transform=Normalize(mean,std),
    )

    # Native pytorch dataloader

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last= False,
        pin_memory=True,
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last= False,
        pin_memory=True,

    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last= False,
        pin_memory=True,

    )

    # Train the model ! 

    best_val_outputs, best_val_labels, model, train_losses, val_losses = train_model(
        model,
        train_loader,
        validation_loader,
        num_epochs,
        lr,
        weight_decay,
        criterion,
        device,
        optimizer_choice,
        patience_epochs,
        patience_loss,
        result_path_folder,
    )

    # Plotting and saving the results

    plot_save_loss(
        best_val_outputs,
        best_val_labels,
        train_losses,
        val_losses,
        result_path_folder,
        saving=False,
    )

    model_file = [file for file in os.listdir(result_path_folder) if file.endswith(".pth")][0]
    model_path = os.path.join(result_path_folder, model_file)
    model.load_state_dict(torch.load(model_path,map_location=torch.device(device)) )

    # Running the test
    test_output, test_target, test_loss = test_model(model, test_loader, criterion, device)

    # Plotting the results
    error_plot(test_output, test_target, result_path_folder)
    rmse_per_range(test_output, test_target, result_path_folder)  # ADD TO METADATA

    np.save(os.path.join(result_path_folder, "test_loss.npy"),
            test_loss)


    # saving the test output and labels
    np.save(
        os.path.join(result_path_folder, "test_output.npy"),
        test_output,
    )
    np.save(
        os.path.join(result_path_folder, "test_labels.npy"),
        test_target,
    )

    return result_path_folder
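
A minimal call sketch, assuming model_parameters and normalization_factors are dictionaries filled in as described above (the .npz path and run name are placeholders):

result_path_folder = train_test_model(
    "./data/matched_orbits.npz",   # hypothetical output of extract_matching_orbits
    "my_first_run",
    model_parameters,
    normalization_factors,
)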

vectorized_solar_angles(lat, lon, time_utc)

This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function. Current implementation does not use solar angles but only image input.

Parameters:

    lat (ndarray): The latitude values of the scatterometer data. [required]
    lon (ndarray): The longitude values of the scatterometer data. [required]
    time_utc (ndarray): The observation times in UTC. [required]

Returns:

    sza (ndarray): The solar zenith angle in degrees.
    saa (ndarray): The solar azimuth angle in degrees.

Source code in windscangeo\func_ml.py
def vectorized_solar_angles(lat, lon, time_utc):

    """
    This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function. Current implementation does not use solar angles but only image input.

    Args:
        lat (numpy.ndarray): The latitude values of the scatterometer data.
        lon (numpy.ndarray): The longitude values of the scatterometer data.
        time_utc (numpy.ndarray): The observation times in UTC.

    Returns:
        sza (numpy.ndarray): The solar zenith angle in degrees.
        saa (numpy.ndarray): The solar azimuth angle in degrees.
    """

    # Convert time to Julian Day
    timestamp = pd.to_datetime(time_utc).tz_localize(None)
    jd = (
        timestamp.astype("datetime64[ns]").astype(np.int64) / 86400000000000 + 2440587.5
    )
    d = jd - 2451545.0  # Days since J2000

    # Mean longitude, mean anomaly, ecliptic longitude
    g = np.deg2rad((357.529 + 0.98560028 * d) % 360)  # Mean anomaly
    q = np.deg2rad((280.459 + 0.98564736 * d) % 360)  # Mean longitude
    L = (q + np.deg2rad(1.915) * np.sin(g) + np.deg2rad(0.020) * np.sin(2 * g)) % (
        2 * np.pi
    )  # Ecliptic long

    # Obliquity of the ecliptic
    e = np.deg2rad(23.439 - 0.00000036 * d)

    # Sun declination
    sin_delta = np.sin(e) * np.sin(L)
    delta = np.arcsin(sin_delta)

    # Equation of time (in minutes)
    E = 229.18 * (
        0.000075
        + 0.001868 * np.cos(g)
        - 0.032077 * np.sin(g)
        - 0.014615 * np.cos(2 * g)
        - 0.040849 * np.sin(2 * g)
    )

    # Convert time to fractional hours (UTC)
    fractional_hour = timestamp.hour + timestamp.minute / 60 + timestamp.second / 3600

    # Solar time correction
    time_offset = E + 4 * lon  # lon in degrees
    tst = fractional_hour * 60 + time_offset  # True Solar Time in minutes
    ha = np.deg2rad((tst / 4 - 180) % 360)  # Hour angle in radians

    # Convert lat/lon to radians
    lat_rad = np.deg2rad(lat)

    # Solar zenith angle
    cos_zenith = np.sin(lat_rad) * np.sin(delta) + np.cos(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    zenith = np.rad2deg(np.arccos(np.clip(cos_zenith, -1, 1)))  # in degrees

    # Solar saa angle
    sin_saa = -np.sin(ha) * np.cos(delta)
    cos_saa = np.cos(lat_rad) * np.sin(delta) - np.sin(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    saa = np.rad2deg(np.arctan2(sin_saa, cos_saa))
    saa = (saa + 360) % 360  # Normalize

    return zenith, saa
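
A short sketch computing a night-time mask, mirroring how plot_wind_speeds uses this function (coordinates and time are placeholders):

import numpy as np
import pandas as pd

lat = np.array([0.0, 10.0])
lon = np.array([-45.0, -30.0])
times = np.full(2, pd.Timestamp("2022-06-01 15:00:00"), dtype="datetime64[ns]")

sza, saa = vectorized_solar_angles(lat, lon, times)
night = sza > 90   # True where the sun is below the horizon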

calculate_degrees(file_id)

This function calculates the latitude and longitude of the GOES ABI fixed grid projection. This function comes from NOAA/NESDIS/STAR. (2025). Latitude and longitude remapping of GOES-R ABI imagery using Python . Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

Parameters:

    file_id (Dataset): The xarray dataset containing the GOES ABI fixed grid projection variables. [required]

Returns:

    abi_lat (ndarray): The latitude of the GOES ABI fixed grid projection.
    abi_lon (ndarray): The longitude of the GOES ABI fixed grid projection.

Source code in windscangeo\func.py
def calculate_degrees(file_id):
    """This function calculates the latitude and longitude of the GOES ABI fixed grid projection. 
    This function comes from NOAA/NESDIS/STAR. (2025). Latitude and longitude remapping of GOES-R ABI imagery using Python . Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

    Args:
        file_id (xarray.Dataset): The xarray dataset containing the GOES ABI fixed grid projection variables.

    Returns:
        abi_lat (numpy.ndarray): The latitude of the GOES ABI fixed grid projection.
        abi_lon (numpy.ndarray): The longitude of the GOES ABI fixed grid projection.


    """

    # Read in GOES ABI fixed grid projection variables and constants
    x_coordinate_1d = file_id.variables["x"][:]  # E/W scanning angle in radians
    y_coordinate_1d = file_id.variables["y"][:]  # N/S elevation angle in radians
    projection_info = file_id.goes_imager_projection
    lon_origin = projection_info.longitude_of_projection_origin
    H = projection_info.perspective_point_height + projection_info.semi_major_axis
    r_eq = projection_info.semi_major_axis
    r_pol = projection_info.semi_minor_axis

    # Create 2D coordinate matrices from 1D coordinate vectors
    x_coordinate_2d, y_coordinate_2d = np.meshgrid(x_coordinate_1d, y_coordinate_1d)

    # Equations to calculate latitude and longitude
    lambda_0 = (lon_origin * np.pi) / 180.0
    a_var = np.power(np.sin(x_coordinate_2d), 2.0) + (
        np.power(np.cos(x_coordinate_2d), 2.0)
        * (
            np.power(np.cos(y_coordinate_2d), 2.0)
            + (
                ((r_eq * r_eq) / (r_pol * r_pol))
                * np.power(np.sin(y_coordinate_2d), 2.0)
            )
        )
    )
    b_var = -2.0 * H * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    c_var = (H**2.0) - (r_eq**2.0)
    r_s = (-1.0 * b_var - np.sqrt((b_var**2) - (4.0 * a_var * c_var))) / (2.0 * a_var)
    s_x = r_s * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    s_y = -r_s * np.sin(x_coordinate_2d)
    s_z = r_s * np.cos(x_coordinate_2d) * np.sin(y_coordinate_2d)

    # Ignore numpy errors for sqrt of negative number; occurs for GOES-16 ABI CONUS sector data
    np.seterr(all="ignore")

    abi_lat = (180.0 / np.pi) * (
        np.arctan(
            ((r_eq * r_eq) / (r_pol * r_pol))
            * ((s_z / np.sqrt(((H - s_x) * (H - s_x)) + (s_y * s_y))))
        )
    )
    abi_lon = (lambda_0 - np.arctan(s_y / (H - s_x))) * (180.0 / np.pi)

    print("INFO : Latitude and longitude calculated")
    return abi_lat, abi_lon
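
A usage sketch, assuming a GOES-16 ABI L2 NetCDF file is available locally (the filename is a placeholder):

import xarray as xr

ds = xr.open_dataset("path/to/OR_ABI-L2-CMIPF_file.nc")  # hypothetical local file
abi_lat, abi_lon = calculate_degrees(ds)
# abi_lat and abi_lon are 2D arrays on the ABI fixed grid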

create_folder(experiment_name)

Create a folder for saving results based on the experiment name.

Parameters:

    experiment_name (str): Name of the experiment to create a folder for. [required]

Returns:

    str: Path to the created folder.

Source code in windscangeo\func.py
def create_folder(experiment_name):
    """
    Create a folder for saving results based on the experiment name.

    Args:
        experiment_name (str): Name of the experiment to create a folder for.
        If the folder already exists, it will not be created again.

    Returns:
        str: Path to the created folder.
    """

    path_folder = f"./results_folder/model_day_{experiment_name}"

    if not os.path.exists(path_folder):
        os.makedirs(path_folder)
        print(f"Folder created at {path_folder}")

    return path_folder
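
A one-line usage sketch:

path_folder = create_folder("my_first_run")
# -> "./results_folder/model_day_my_first_run" (created if it does not already exist)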

extract_goes(observation_times, observation_lats, observation_lons, scatterometer_data_path, goes_aws_url_folder, goes_channel='C01', goes_image_size=128, verbose=True)

This function extracts GOES images for the given observation times, latitudes, and longitudes. It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

Parameters:

    observation_times (ndarray): The times of observation of the scatterometer data. [required]
    observation_lats (ndarray): The latitudes of the scatterometer data. [required]
    observation_lons (ndarray): The longitudes of the scatterometer data. [required]
    scatterometer_data_path (str): The path to the scatterometer data directory. [required]
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. [required]
    goes_channel (str): The channel of interest. [default: 'C01']
    goes_image_size (int): The size of the output images. [default: 128]
    verbose (bool): If True, prints progress information. [default: True]

Returns:

    images (ndarray): A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

Source code in windscangeo\func.py
def extract_goes(
    observation_times,
    observation_lats,
    observation_lons,
    scatterometer_data_path,
    goes_aws_url_folder,
    goes_channel="C01",
    goes_image_size=128,
    verbose=True,
):
    """
    This function extracts GOES images for the given observation times, latitudes, and longitudes.
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data. 
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        scatterometer_data_path (str): The path to the scatterometer data directory.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored.
        goes_channel (str): The channel of interest. Default is "C01".
        goes_image_size (int): The size of the output images. Default is 128.
        verbose (bool): If True, prints progress information.

    Returns:
        images (numpy.ndarray): A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

    """

    for file in os.listdir(scatterometer_data_path):
        if file.endswith(".nc"):
            polar = xr.open_dataset(
                os.path.join(scatterometer_data_path, file),
                engine="h5netcdf",
                drop_variables=["DQF"],
            )
            break

        else:
            print('WARNING : No .nc file found in the scatterometer data path, please check the path')

    template_scatter = polar.isel(time=0)
    lat_grd, lon_grd = (
        template_scatter["latitude"].values,
        template_scatter["longitude"].values,
    )

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**1024)

    values, counts = np.unique(observation_times, return_counts=True)

    all_urls = []  # getting unique URLS
    for value in values:
        urls = get_goes_url(value, goes_aws_url_folder,goes_channel)
        all_urls.append(urls)

    values_url, indices_url, counts_url = np.unique(
        all_urls, return_index=True, return_counts=True, axis=0
    )
    # Sort indices to "unsort" the URLs
    sorted_indices = sorted(range(len(indices_url)), key=lambda k: indices_url[k])
    values_url = [all_urls[indices_url[i]] for i in sorted_indices]

    # Reorder counts_url using the same sorted indices
    counts_url = [counts_url[i] for i in sorted_indices]

    compressed_urls = values_url
    compressed_counts = []
    start_idx = 0

    for size in counts_url:
        group_sum = counts[start_idx : start_idx + size].sum()
        compressed_counts.append(group_sum)
        start_idx += size

    width = goes_image_size
    height = goes_image_size

    images = np.zeros([len(observation_times), 1 , width, height], dtype=np.float32)

    total_idx = 0
    for unique_idx, unique_urls in tqdm(
        enumerate(compressed_urls),
        desc="INFO : Retrieving and processing GOES data",
        total=len(compressed_urls),
        disable=not verbose,
    ):


        for CH_idx, url_CH in enumerate(unique_urls):

            if url_CH == 0:
                images[total_idx, CH_idx] = np.zeros([width, height])
                continue

            with fs.open(url_CH, mode="rb") as f:

                ds = xr.open_dataset(
                    f, engine="h5netcdf", drop_variables=["DQF"]
                )  # this is the bottleneck

                parallel_index = index_parallel(
                    ds,
                    template_scatter,
                )
                for i in range(compressed_counts[unique_idx]):
                    images[total_idx + i, CH_idx] = get_image(
                        ds=ds,
                        parallel_index=parallel_index,
                        lat_grd=lat_grd,
                        lon_grd=lon_grd,
                        lat_search=observation_lats[total_idx + i],
                        lon_search=observation_lons[total_idx + i],
                        goes_image_size=goes_image_size,
                    )

        total_idx += compressed_counts[unique_idx]

    if verbose:
        print(
            f"INFO : Extracted {len(observation_times)} images from {len(compressed_urls)} GOES files."
        )
    return images
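
A hedged call sketch (requires anonymous S3 access to the GOES bucket; the local scatterometer directory and the observation values are placeholders):

import numpy as np

observation_times = np.array(["2022-06-01T14:02"], dtype="datetime64[ns]")
observation_lats = np.array([12.5])
observation_lons = np.array([-45.0])

images = extract_goes(
    observation_times, observation_lats, observation_lons,
    "./data/scatterometer/",          # hypothetical directory containing the matched .nc grid
    "noaa-goes16/ABI-L2-CMIPF",
    goes_channel="C01", goes_image_size=128,
)
# images.shape -> (1, 1, 128, 128)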

extract_goes_inference(date_time, parallel_index, channels='C01', goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF')

This function extracts GOES images for a given date_time and parallel_index. (whole GOES slice, used for inference which differs from images used in training that have a matched orbit with scatterometers.) It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size (128x128).

Parameters:

    date_time (datetime64): The time of the GOES data. [required]
    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. [required]
    channels (str or list): The channel(s) of interest. [default: 'C01']
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. [default: 'noaa-goes16/ABI-L2-CMIPF']

Returns:

    images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).

Source code in windscangeo\func.py
def extract_goes_inference(date_time, parallel_index,channels="C01",goes_aws_url_folder= 'noaa-goes16/ABI-L2-CMIPF'):
    """
    This function extracts GOES images for a given date_time and parallel_index. It processes a whole GOES slice and is used for inference,
    unlike the training images, which are matched to scatterometer orbits.
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create
    images of the specified size (128x128).

    Args:
        date_time (numpy.datetime64): The time of the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        channels (str or list): The channel(s) of interest. Default is "C01".
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Default is 'noaa-goes16/ABI-L2-CMIPF'.

    Returns:
        images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).
    """

    # ignore divide by zero errors which occur when the GOES data can't form a 128x128 image
    np.seterr(invalid='ignore', divide='ignore')

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**2)  # 512 MiB block size
    urls = get_goes_url(date_time, goes_aws_url_folder=goes_aws_url_folder, goes_channel=channels)
    with fs.open(urls[0], mode="rb") as f:
        print("INFO : Reading file:", urls[0])
        goes_image = xr.open_dataset(f)
        goes_image = goes_image.rename({"x": "x_index", "y": "y_index"})

        # Assign the index coordinates (if not already done)
        goes_image = goes_image.assign_coords(
            x_index=np.arange(goes_image.sizes["x_index"]),
            y_index=np.arange(goes_image.sizes["y_index"]),
        )

        images = []
        goes_image.load()
        print("INFO : Extracting images")
        for i in range(parallel_index.shape[0]):
            for j in range(parallel_index.shape[1]):
                try:
                    x_mean = parallel_index[i][j][1].mean().astype(int)
                    x_min = x_mean - 63
                    x_max = x_mean + 63
                    y_mean = parallel_index[i][j][0].mean().astype(int)
                    y_min = y_mean - 63
                    y_max = y_mean + 63
                    image = goes_image.CMI.sel(
                        x_index=slice(x_min, x_max), y_index=slice(y_min, y_max)
                    )

                    target_size = (128, 128)

                    padded_image = np.pad(
                        image,
                        (
                            (
                                (target_size[0] - image.shape[0]) // 2,
                                (target_size[0] - image.shape[0] + 1) // 2,
                            ),
                            (
                                (target_size[1] - image.shape[1]) // 2,
                                (target_size[1] - image.shape[1] + 1) // 2,
                            ),
                        ),
                        constant_values=0,
                    )

                except:
                    images.append(np.zeros((128, 128)))

                    continue
                images.append(padded_image)

        return images
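
A minimal usage sketch, assuming the function is importable from `windscangeo.func` and that `parallel_index` was produced earlier by `index_parallel` for the grid of interest; the timestamp is purely illustrative.

```python
import numpy as np
from windscangeo.func import extract_goes_inference

# parallel_index is assumed to come from a previous index_parallel() call for this grid.
date_time = np.datetime64("2022-06-01T15:00:00")
images = extract_goes_inference(
    date_time,
    parallel_index,
    channels="C01",
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
)
print(len(images), images[0].shape)  # one (128, 128) array per grid cell
```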

extract_scatter(polar_data, date, lat_range, lon_range, verbose=True, main_variable='wind_speed')

This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range and longitude range. The function then saves the data into 4 numpy files : time of observation, latitude, longitude and main variable.

Parameters:

- polar_data (Dataset, required): The scatterometer dataset (ASCAT, HYSCAT etc).
- date (datetime64, required): The time of the scatterometer data.
- lat_range (tuple, required): The latitude range of the scatterometer data.
- lon_range (tuple, required): The longitude range of the scatterometer data.
- verbose (bool, default True): If True, the function will print the progress of the extraction.
- main_variable (str, default 'wind_speed'): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification etc.

Returns:

- observation_times (ndarray): The time of observation of the scatterometer data.
- observation_lats (ndarray): The latitude of the scatterometer data.
- observation_lons (ndarray): The longitude of the scatterometer data.
- observation_main_parameter (ndarray): The main parameter extracted (wind_speed).

Source code in windscangeo\func.py
def extract_scatter(
    polar_data,
    date,
    lat_range,
    lon_range,
    verbose=True,
    main_variable="wind_speed",
):
    """
    This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range and longitude range.
    The function then saves the data into 4 numpy files : time of observation, latitude, longitude and main variable.

    Args:
        polar_data (xarray.Dataset): The scatterometer dataset (ASCAT, HYSCAT etc).
        date (numpy.datetime64): The time of the scatterometer data.
        lat_range (tuple): The latitude range of the scatterometer data.
        lon_range (tuple): The longitude range of the scatterometer data.
        verbose (bool): If True, the function will print the progress of the extraction.
        main_variable (str): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification etc.

    Returns:
        observation_times (numpy.ndarray): The time of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitude of the scatterometer data.
        observation_lons (numpy.ndarray): The longitude of the scatterometer data.
        observation_main_parameter (numpy.ndarray): main parameter extracted (wind_speed).

    """

    polar = polar_data.sel(
        time=slice(date, date),
        latitude=slice(lat_range[0], lat_range[1]),
        longitude=slice(lon_range[0], lon_range[1]),
    )

    seperated_scatter = savedataseperated(polar, polar[main_variable],verbose=verbose)

    observation_times = seperated_scatter[2]
    observation_lats = seperated_scatter[0]
    observation_lons = seperated_scatter[1]
    observation_wind_speeds = seperated_scatter[3]



    return (
        observation_times,
        observation_lats,
        observation_lons,
        observation_wind_speeds,
    )
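
A usage sketch, assuming the function is importable from `windscangeo.func` and that a scatterometer file is available locally; the file name, date and region are placeholders. The dataset is expected to carry time/latitude/longitude coordinates, a `measurement_time` variable and a `source` attribute, as used by `savedataseperated` further down this page.

```python
import numpy as np
import xarray as xr
from windscangeo.func import extract_scatter

polar_data = xr.open_dataset("ascat_l3_winds.nc")  # hypothetical local file

times, lats, lons, wind_speeds = extract_scatter(
    polar_data,
    date=np.datetime64("2022-06-01"),
    lat_range=(10, 30),
    lon_range=(-90, -60),
    verbose=True,
    main_variable="wind_speed",
)
```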

fill_nans(images)

This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

Parameters:

- images (ndarray, required): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.

Returns:

- images (ndarray): A 4D numpy array with NaN values replaced by zeros.

Source code in windscangeo\func.py
def fill_nans(images):
    """
    This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.

    Returns:
        images (numpy.ndarray): A 4D numpy array with NaN values replaced by zeros.
    """
    images = np.nan_to_num(images, nan=0.0)
    print("INFO : Filled nans")
    return images

filter_invalid(images, numerical_data, min_nonzero_pixels=50)

This function filters out invalid images and corresponding numerical data based on two criteria: 1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty). 2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

Parameters:

- images (ndarray, required): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.
- numerical_data (dict, required): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images.
- min_nonzero_pixels (int, default 50): The minimum number of non-zero pixels required for an image to be considered valid.

Returns:

- filtered_images (ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
- filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

Source code in windscangeo\func.py
def filter_invalid(
    images,
    numerical_data,
    min_nonzero_pixels=50,
):

    """
    This function filters out invalid images and corresponding numerical data based on two criteria:
    1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty).
    2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.
        numerical_data (dict): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images.
        min_nonzero_pixels (int): The minimum number of non-zero pixels required for an image to be considered valid. Default is 50.

    Returns:
        filtered_images (numpy.ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
        filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

    """
    # Sums of pixel values in each image
    sums_images = [np.nansum(x) for x in images]

    # Counts of non-zero pixels in each image
    nonzero_counts = [np.count_nonzero(x) for x in images]

    # Build a "mask_invalid" array of indices that fail any criterion:
    # 1) sum == 0 (completely empty)
    # 2) nonzero pixel count < min_nonzero_pixels (not enough data)

    mask_valid = np.where(
        (np.array(sums_images) != 0) & (np.array(nonzero_counts) >= min_nonzero_pixels)
    )[0]

    # Delete the invalid entries from each array
    filtered_numerical_data = {
        key: value[mask_valid] for key, value in numerical_data.items()
    }
    filtered_images = images[mask_valid]
    n_removed_images = len(images) - len(filtered_images)

    print(
        "INFO : Filtered invalid images. Removed {} entries.".format(
            n_removed_images
        )
    )
    return (
        filtered_images,
        filtered_numerical_data,
    )
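
A self-contained toy example of the filtering behaviour; the array values are synthetic and only meant to show how the image axis and the dictionary entries stay aligned.

```python
import numpy as np
from windscangeo.func import filter_invalid

# Toy batch: three single-channel 128x128 images; the second is completely empty
# and should be dropped together with its numerical entries.
images = np.random.rand(3, 1, 128, 128).astype(np.float32)
images[1] = 0.0

numerical_data = {
    "observation_lats": np.array([21.0, 22.0, 23.0]),
    "observation_lons": np.array([-85.0, -84.0, -83.0]),
    "wind_speeds": np.array([5.2, 7.1, 6.4]),
}

filtered_images, filtered_numerical_data = filter_invalid(
    images, numerical_data, min_nonzero_pixels=50
)
print(filtered_images.shape)  # (2, 1, 128, 128)
```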

filter_nighttime(observation_times, observation_lats, observation_lons, observation_wind_speeds, min_hour=10, max_hour=19, verbose=True)

This function filters the scatterometer data to only include observations that were made during daylight hours. The function checks the hour of each observation time and only keeps those that fall within the specified range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

Parameters:

- observation_times (ndarray, required): The times of observation of the scatterometer data.
- observation_lats (ndarray, required): The latitudes of the scatterometer data.
- observation_lons (ndarray, required): The longitudes of the scatterometer data.
- observation_wind_speeds (ndarray, required): The wind speeds of the scatterometer data.
- min_hour (int, default 10): The minimum hour of the day to include.
- max_hour (int, default 19): The maximum hour of the day to include.
- verbose (bool, default True): If True, prints the number of valid scatterometer data points at daylight.

Returns:

- valid_times (list): A list of valid observation times that fall within the specified hour range.
- valid_lats (list): A list of valid latitudes corresponding to the valid observation times.
- valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
- valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

Source code in windscangeo\func.py
def filter_nighttime(
    observation_times,
    observation_lats,
    observation_lons,
    observation_wind_speeds,
    min_hour=10,
    max_hour=19,
    verbose=True,
):
    """
    This function filters the scatterometer data to only include observations that were made during daylight hours.
    The function checks the hour of each observation time and only keeps those that fall within the specified
    range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        observation_wind_speeds (numpy.ndarray): The wind speeds of the scatterometer data.
        min_hour (int): The minimum hour of the day to include (default is 10).
        max_hour (int): The maximum hour of the day to include (default is 19).
        verbose (bool): If True, prints the number of valid scatterometer data points at daylight.

    Returns:
        valid_times (list): A list of valid observation times that fall within the specified hour range.
        valid_lats (list): A list of valid latitudes corresponding to the valid observation times
        valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
        valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

    """

    valid_times = []
    valid_lats = []
    valid_lons = []
    valid_wind_speeds = []

    for idx in range(len(observation_times)):
        only_hour = int(
            observation_times[idx].astype("datetime64[ns]").astype("str")[11:13]
        )
        if min_hour <= only_hour <= max_hour:
            valid_times.append(observation_times[idx])
            valid_lats.append(observation_lats[idx])
            valid_lons.append(observation_lons[idx])
            valid_wind_speeds.append(observation_wind_speeds[idx])

    if verbose:
        print(f"INFO : Total number of scatterometer data points at daylight : {len(valid_times)}")
    return valid_times, valid_lats, valid_lons, valid_wind_speeds
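
A short sketch with two synthetic observations; only the one inside the default 10-19 UTC window is kept.

```python
import numpy as np
from windscangeo.func import filter_nighttime

times = np.array(["2022-06-01T03:30", "2022-06-01T15:45"], dtype="datetime64[ns]")
lats = np.array([21.0, 22.0])
lons = np.array([-85.0, -84.0])
speeds = np.array([5.2, 7.1])

day_times, day_lats, day_lons, day_speeds = filter_nighttime(times, lats, lons, speeds)
print(day_times)  # only the 15:45 UTC observation remains
```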

get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel='C01')

This function gets the nearest GOES-16 files from the time given. The function returns a list of urls to the files. The function uses the s3fs library to access the AWS GOES-16 data.

Parameters:

- time (datetime[ns], required): The time of the scatterometer data.
- goes_aws_url_folder (str, default 'noaa-goes16/ABI-L2-CMIPF'): The folder in the AWS S3 bucket where the GOES-16 data is stored.
- goes_channel (list, default 'C01'): The channel of interest.

Returns:

- urls (list): A list of urls to the GOES-16 files.

Source code in windscangeo\func.py
def get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel="C01"):
    """
    This function gets the nearest GOES-16 files from the time given.
    The function returns a list of urls to the files.
    The function uses the s3fs library to access the AWS GOES-16 data.

    Args:
        time (numpy.datetime[ns]): The time of the scatterometer data.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES-16 data is stored.
        goes_channel (list): The channel of interest.

    Returns:
        urls (list): A list of urls to the GOES-16 files.


    """
    date_c = time.astype("datetime64[ns]")
    date = pd.to_datetime(date_c)
    date_str = date.strftime("%Y/%j/%H")
    min = int(date.strftime("%M"))
    min_range = [(min + i) % 60 for i in range(-6, 7)]
    min_range_str = [f"{x:02d}" for x in min_range]
    fs = s3fs.S3FileSystem(anon=True)
    # get the nearest goes file from time

    urls = []
    channel = goes_channel
    path = f"{goes_aws_url_folder}/{date_str}"
    files = fs.ls(path)
    filter_channel = [x for x in files if channel in x]
    if len(filter_channel) == 0:
        print(f"INFO :No file found for {channel} on day {date_str}, skipping file")
        return
    file = [x for x in filter_channel if x[73:75] in min_range_str]
    if len(file) == 0:
        print(
            f"INFO :No file found for {channel} on day {date_str} for minute {min}, skipping file"
        )
        return np.zeros(len(goes_channel))
    urls.append("s3://" + file[0])

    return urls
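
A usage sketch with anonymous S3 access and an illustrative timestamp; the function looks for a file of the requested channel within roughly plus or minus six minutes of the given time.

```python
import numpy as np
from windscangeo.func import get_goes_url

urls = get_goes_url(
    np.datetime64("2022-06-01T15:02:00"),
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
    goes_channel="C01",
)
print(urls)  # a one-element list of s3:// URLs when a matching file is found
```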

get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search, goes_image_size=128)

This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 .nc file.

Parameters:

- ds (Dataset, required): The xarray dataset containing the GOES data.
- parallel_index (ndarray, required): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
- lat_grd (ndarray, required): The latitude grid of the scatterometer data.
- lon_grd (ndarray, required): The longitude grid of the scatterometer data.
- lat_search (float, required): The latitude to search for in the GOES data.
- lon_search (float, required): The longitude to search for in the GOES data.
- goes_image_size (int, default 128): The size of the output image.

Returns:

- padded_image (DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

Source code in windscangeo\func.py
def get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search,goes_image_size=128):

    """
    This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 `.nc` file.

    Args:
        ds (xarray.Dataset): The xarray dataset containing the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.
        goes_image_size (int): The size of the output image. Default is 128.

    Returns:
        padded_image (xarray.DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

    """
    index_row = np.where(
        lat_grd == lat_search,
    )
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    pixels_from_center = (goes_image_size-1) // 2
    mean_row = rows_goes.mean().astype(int)
    min_row = mean_row - pixels_from_center
    max_row = mean_row + pixels_from_center

    mean_col = columns_goes.mean().astype(int)
    min_col = mean_col - pixels_from_center
    max_col = mean_col + pixels_from_center

    if "CMI" in ds: # If using GOES-16 L2 processed data
        image = ds.CMI[min_row:max_row, min_col:max_col].values

    elif "Rad" in ds: #If using GOES-16 L1b data
        image = ds.Rad[min_row:max_row, min_col:max_col].values

    # debug
    # print(min_row,'= min_row', max_row,'= max_row', min_col, '= min_col', max_col, '= max_col')
    target_size = (goes_image_size, goes_image_size)

    padded_image = np.pad(
        image,
        (
            (
                (target_size[0] - image.shape[0]) // 2,
                (target_size[0] - image.shape[0] + 1) // 2,
            ),
            (
                (target_size[1] - image.shape[1]) // 2,
                (target_size[1] - image.shape[1] + 1) // 2,
            ),
        ),
        constant_values=0,
    )

    padded_image = xr.DataArray(padded_image, dims=("x", "y"))
    return padded_image
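
The centring logic at the end of get_image can be hard to parse at a glance; the short standalone sketch below reproduces the same np.pad arithmetic on a window that was clipped at the edge of the disk.

```python
import numpy as np

image = np.ones((120, 127))   # window that could not reach the full 128x128 extent
target_size = (128, 128)

padded_image = np.pad(
    image,
    (
        ((target_size[0] - image.shape[0]) // 2, (target_size[0] - image.shape[0] + 1) // 2),
        ((target_size[1] - image.shape[1]) // 2, (target_size[1] - image.shape[1] + 1) // 2),
    ),
    constant_values=0,
)
print(padded_image.shape)  # (128, 128): 4 zero rows above and below, 1 zero column on the right
```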

get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125)

Finds the corresponding GOES row and column indices for each scatterometer point using a BallTree for efficiency, and then filtering points to form a square bounding box.

Parameters:

- lat_grid (ndarray, required): 2D array of latitudes from the scatterometer data.
- lon_grid (ndarray, required): 2D array of longitudes from the scatterometer data.
- Goeslat (ndarray, required): 2D array of latitudes from the GOES data.
- Goeslon (ndarray, required): 2D array of longitudes from the GOES data.
- radius (float, default 0.125): Radius in degrees to define the bounding box around each scatterometer point.

Returns:

- indices_array (ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.

Source code in windscangeo\func.py
def get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125):
    """
    Finds the corresponding GOES row and column indices for each scatterometer point
    using a BallTree for efficiency, and then filtering points to form a square bounding box.

    Args:
        lat_grid (numpy.ndarray): 2D array of latitudes from the scatterometer data.
        lon_grid (numpy.ndarray): 2D array of longitudes from the scatterometer data.
        Goeslat (numpy.ndarray): 2D array of latitudes from the GOES data.
        Goeslon (numpy.ndarray): 2D array of longitudes from the GOES data.
        radius (float): Radius in degrees to define the bounding box around each scatterometer point.
    Returns:
        indices_array (numpy.ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.


    """

    print("INFO : Calculating indices")
    # Flatten GOES data
    Goeslat_flat = Goeslat.flatten()
    Goeslon_flat = Goeslon.flatten()
    goes_points = np.column_stack((Goeslat_flat, Goeslon_flat))

    # Build BallTree with haversine distance
    goes_points_rad = np.radians(goes_points)
    goes_tree = BallTree(goes_points_rad, metric="haversine")

    # Flatten scatter grids
    lat_flat = lat_grid.flatten()
    lon_flat = lon_grid.flatten()
    scatter_points = np.column_stack((lat_flat, lon_flat))
    scatter_points_rad = np.radians(scatter_points)

    # Radius for broad-phase query: diagonal of the bounding box
    # Square box ±radius: diagonal = radius * sqrt(2)
    diag_radius = radius * np.sqrt(2)
    diag_radius_rad = np.radians(diag_radius)

    indices_array = np.empty(lat_flat.shape, dtype=object)
    goes_shape = Goeslat.shape

    for i, (lat_val, lon_val) in enumerate(zip(lat_flat, lon_flat)):
        # Broad-phase: query all points within diagonal distance
        candidate_indices = goes_tree.query_radius(
            np.array([scatter_points_rad[i]]), r=diag_radius_rad
        )[0]

        if candidate_indices.size == 0:
            # No points found, store empty
            indices_array[i] = (np.array([], dtype=int), np.array([], dtype=int))
            continue

        # Post-filter candidates to keep only those in the bounding box
        lat_min = lat_val - radius
        lat_max = lat_val + radius
        lon_min = lon_val - radius
        lon_max = lon_val + radius

        cand_lats = Goeslat_flat[candidate_indices]
        cand_lons = Goeslon_flat[candidate_indices]

        mask = (
            (cand_lats >= lat_min)
            & (cand_lats <= lat_max)
            & (cand_lons >= lon_min)
            & (cand_lons <= lon_max)
        )

        final_indices = candidate_indices[mask]

        # Convert these flat indices back to row,col
        rows, cols = np.unravel_index(final_indices, goes_shape)
        indices_array[i] = (rows, cols)

    # Reshape indices_array to the original shape
    indices_array = indices_array.reshape(lat_grid.shape)
    return indices_array
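
A minimal, self-contained sketch on synthetic grids: a coarse "scatterometer" grid and a finer fake "GOES" grid over the same area. Each cell of the returned array holds the GOES rows and columns that fall inside the bounding box of half-width radius around that scatterometer point.

```python
import numpy as np
from windscangeo.func import get_indices

scat_lon, scat_lat = np.meshgrid(np.arange(-85.0, -84.0, 0.25), np.arange(20.0, 21.0, 0.25))
goes_lon, goes_lat = np.meshgrid(np.arange(-85.5, -83.5, 0.02), np.arange(19.5, 21.5, 0.02))

indices_array = get_indices(scat_lat, scat_lon, goes_lat, goes_lon, radius=0.125)
rows, cols = indices_array[0, 0]   # GOES pixels matched to the first scatterometer cell
print(rows.size, cols.size)
```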

goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search)

This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

Parameters:

- parallel_index (ndarray, required): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
- lat_grd (ndarray, required): The latitude grid of the scatterometer data.
- lon_grd (ndarray, required): The longitude grid of the scatterometer data.
- lat_search (float, required): The latitude to search for in the GOES data.
- lon_search (float, required): The longitude to search for in the GOES data.

Returns:

- min_row (int): The minimum row index of the GOES image.
- max_row (int): The maximum row index of the GOES image.
- min_col (int): The minimum column index of the GOES image.
- max_col (int): The maximum column index of the GOES image.

Source code in windscangeo\func.py
def goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search):
    """
    This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

    Args:
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.

    Returns:
        min_row (int): The minimum row index of the GOES image.
        max_row (int): The maximum row index of the GOES image.
        min_col (int): The minimum column index of the GOES image.
        max_col (int): The maximum column index of the GOES image.
    """

    index_row = np.where(lat_grd == lat_search)
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    min_row = rows_goes.min()
    max_row = rows_goes.max()

    min_col = columns_goes.min()
    max_col = columns_goes.max()

    return min_row, max_row, min_col, max_col

index_parallel(ds, ScatterDataset)

Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

Parameters:

- ds (Dataset, required): The GOES xarray Dataset used to derive the pixel coordinates and the spatial resolution used for caching.
- ScatterDataset (Dataset, required): xarray Dataset containing scatterometer data.

Returns:

- parallel_indice_values (ndarray): 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.

Source code in windscangeo\func.py
def index_parallel(ds, ScatterDataset):
    """
    Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

    Args:
        ds (xarray.Dataset): The GOES dataset used to derive the pixel coordinates and the spatial resolution used for caching.
        ScatterDataset (xarray.Dataset): xarray Dataset containing scatterometer data.

    Returns:
        parallel_indice_values: 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.
    """

    create_folder("satellite_indices")
    ds_spatial_resolution = ds.spatial_resolution
    ds_spatial_resolution = ds_spatial_resolution.replace(" ", "_")  # replace spaces so the cache file name stays valid

    name_str = f"lat_{ScatterDataset.latitude.min().values}_{ScatterDataset.latitude.max().values}_lon_{ScatterDataset.longitude.min().values}_{ScatterDataset.longitude.max().values}_res_{ds_spatial_resolution}"
    name_str = name_str.replace(".", "_")
    if os.path.exists(
        f"./satellite_indices/{ds_spatial_resolution}_index.npy"
    ):
        parallel_index = np.load(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            allow_pickle=True,
        )

        return parallel_index

    else:
        print(
            "INFO : Satellite index file not found, creating new index file. This might take a while."
        )

        # Extract scatterometer latitudes and longitudes
        Latitudes_Scatter = ScatterDataset["latitude"].values
        Longitudes_Scatter = ScatterDataset["longitude"].values

        # Create a meshgrid of scatterometer coordinates
        lon_grid, lat_grid = np.meshgrid(Longitudes_Scatter, Latitudes_Scatter)

        # Extract GOES latitudes and longitudes
        Goeslat, Goeslon = calculate_degrees(ds)
        Goeslat[np.isnan(Goeslat)] = 999
        Goeslon[np.isnan(Goeslon)] = 999
        # Use the optimized get_indices function
        parallel_indice_values = get_indices(lat_grid, lon_grid, Goeslat, Goeslon)

        # Save the indices array
        np.save(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            parallel_indice_values,
        )

        return parallel_indice_values
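
A usage sketch, assuming `polar_subset` is the scatterometer Dataset already cropped to the region of interest (for example the selection performed inside extract_scatter) and that a GOES file has been located with get_goes_url. Results are cached under ./satellite_indices, keyed by the GOES spatial_resolution attribute.

```python
import fsspec
import numpy as np
import xarray as xr
from windscangeo.func import get_goes_url, index_parallel

fs = fsspec.filesystem("s3", anon=True)
url = get_goes_url(np.datetime64("2022-06-01T15:02:00"))[0]

with fs.open(url, mode="rb") as f:
    goes_ds = xr.open_dataset(f, engine="h5netcdf", drop_variables=["DQF"])
    parallel_index = index_parallel(goes_ds, polar_subset)  # polar_subset: cropped scatterometer Dataset
```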

package_data(images, numerical_data, filter=True, solar_conversion=False, verbose=True)

This function packages the images and numerical data into a format that can be used for training a machine learning model. The function will filter out invalid images and fill in any NaN values. (Invalid images = empty images from GOES data) The function will also convert the observation times, latitudes and longitudes to solar angles (sza, saa) if solar_conversion is set to True. The function will return the images and numerical data in a numpy array format.

Parameters:

- images (ndarray, required): The GOES images corresponding to the observation data.
- numerical_data (dict, required): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds".
- filter (bool, default True): If True, the function will filter out invalid images and fill in NaN values.
- solar_conversion (bool, default False): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). (Not used in the current implementation, but kept in case of future use.)
- verbose (bool, default True): If True, the function will print progress information.

Returns:

- images (ndarray): The GOES images corresponding to the observation data.
- numerical_data (ndarray): The numerical data corresponding to the observation data (sza, saa, main_parameter if solar_conversion is set to True; lat, lon, time, wind_speeds if solar_conversion is set to False).

Source code in windscangeo\func.py
def package_data(
    images,
    numerical_data,
    filter=True,
    solar_conversion=False,
    verbose=True
):
    """
    This function packages the images and numerical data into a format that can be used for training a machine learning model.
    The function will filter out invalid images and fill in any NaN values. (Invalid images = empty images from GOES data)
    The function will also convert the observation times, latitudes and longitudes to solar angles (sza, saa) if solar_conversion is set to True.
    The function will return the images and numerical data in a numpy array format.

    Args:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (dict): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds".
        filter (bool): If True, the function will filter out invalid images and fill in Nan values. Default is True.
        solar_conversion (bool): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). Default is False. (Not used in current implementation, but kept in case of future use)
        verbose (bool): If True, the function will print progress information. Default is True.

    Returns:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (numpy.ndarray): The numerical data corresponding to the observation data. (sza, saa, main_parameter if solar_conversion is set to True or lat, lon, time, wind_speeds if solar_conversion is set to False)

    """
    if filter:
        (images, numerical_data) = filter_invalid(images, numerical_data)
        images = fill_nans(images)

    if solar_conversion:
        observation_lats = numerical_data["observation_lats"]
        observation_lons = numerical_data["observation_lons"]
        observation_times = numerical_data["observation_times"]

        sza, saa = vectorized_solar_angles(
            observation_lats, observation_lons, observation_times
        )

        sza_rad = np.deg2rad(sza)
        sza_sin = np.sin(sza_rad)
        sza_cos = np.cos(sza_rad)

        saa_rad = np.deg2rad(saa)
        saa_sin = np.sin(saa_rad)
        saa_cos = np.cos(saa_rad)

        # Add the solar angles to the numerical data dictionary

        numerical_data["sza_sin"] = sza_sin
        numerical_data["sza_cos"] = sza_cos
        numerical_data["saa_sin"] = saa_sin
        numerical_data["saa_cos"] = saa_cos

        print("Data Preparation : converted to solar angles (sza, saa)")
        print("Data Preparation : returning images, numerical_data")
        return images, numerical_data

    else:

        return images, numerical_data
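
A sketch of the expected dictionary layout, assuming `images` and the observation arrays come from the extraction and filtering steps above; the key names follow the docstring.

```python
import numpy as np
from windscangeo.func import package_data

# images, times, lats, lons, wind_speeds are assumed to come from the steps above.
numerical_data = {
    "observation_times": np.asarray(times),
    "observation_lats": np.asarray(lats),
    "observation_lons": np.asarray(lons),
    "wind_speeds": np.asarray(wind_speeds),
}

images_ready, numerical_ready = package_data(images, numerical_data, filter=True)
```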

save_overpass_time(time_list, name_scatter)

This function prints the overpass time of the scatterometer.

Parameters:

- time_list (ndarray, required): The measurement time values of the scatterometer data.
- name_scatter (str, required): The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc).

Returns:

- None

Source code in windscangeo\func.py
def save_overpass_time(time_list,name_scatter):
    """
    This function prints the overpass time of the scatterometer.

    Args:
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        name_scatter (str): The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc).

    Returns:
        None 

    """
    formated_time = time_list.astype('datetime64[ns]')
    hour_minute = formated_time.astype('datetime64[m]')
    unique_hour_minute = np.unique(hour_minute)

    filtered = [unique_hour_minute[0]]

    delta = np.timedelta64(1, 'h')

    for time in unique_hour_minute[1:]:
        if time - filtered[-1] >= delta:
            filtered.append(time)

    time_only = []
    for time in filtered:
        time = str(time).split('T')[1]
        time_only.append(time)
    print(f"ORBIT : {name_scatter} overpass time : {time_only}")

savedataseperated(ScatterData, main_parameter, verbose=True)

This function extracts the valid lon / lat / measurement time and the main parameter from every pixel of the scatterometer data and saves it to a numpy file.

Parameters:

- ScatterData (Dataset, required): The ASCAT dataset containing the scatterometer data.
- main_parameter (DataArray, required): The main parameter to be saved. This can be a classification / wind speed / wind direction etc.

Returns:

- lat_list (ndarray): The latitude values of the scatterometer data.
- lon_list (ndarray): The longitude values of the scatterometer data.
- time_list (ndarray): The measurement time values of the scatterometer data.
- main_parameter_list (ndarray): The main parameter values of the scatterometer data.

This function saves the data locally to a folder called data_processed_scat.

Source code in windscangeo\func.py
def savedataseperated(ScatterData, main_parameter,verbose=True):
    """
    This function extracts the valid lon / lat / measurement time and the main parameter from every pixel
    of the scatterometer data and saves it to a numpy file.

    Args:
        ScatterData (xarray.Dataset): The ASCAT dataset containing the scatterometer data.
        main_parameter (xarray.DataArray): The main parameter to be saved. This can be a classification / wind speed / wind direction etc.

    Returns:

        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        main_parameter_list (numpy.ndarray): The main parameter values of the scatterometer data.

    this function saves the data locally to a folder called data_processed_scat
    """
    lat_full, lon_full, time_full = ScatterData.indexes.values()
    measurement_time_full = ScatterData.measurement_time

    lat_full = np.array(lat_full)
    lon_full = np.array(lon_full)
    measurement_time_full = np.array(measurement_time_full)
    main_parameter = np.array(main_parameter)

    index = np.argwhere(~np.isnan(main_parameter))

    index_list = []
    lat_list = []
    lon_list = []
    time_list = []
    wind_speed_list = []

    name_scatter = ScatterData.source

    for t, i, j in index:

        # print(t,'= time', i,'=row', j, '=column')
        index_list.append((t, i, j))

        # print(measurement_time_full[t, i, j].astype('datetime64[ns]'))
        time_list.append(measurement_time_full[t, i, j])

        # print(lat_full[i])
        lat_list.append(lat_full[i])

        # print(lon_full[j])
        lon_list.append(lon_full[j])

        # print(AllWindSpeeds[t, i, j])
        wind_speed_list.append(main_parameter[t, i, j])

    lat_list = np.array(lat_list)
    lon_list = np.array(lon_list)
    time_list = np.array(time_list)
    wind_speed_list = np.array(wind_speed_list)

    lat_list, lon_list, time_list, wind_speed_list = sort_by_time(
        lat_list, lon_list, time_list, wind_speed_list
    )
    if verbose:
        save_overpass_time(time_list,name_scatter)    

    return lat_list, lon_list, time_list, wind_speed_list

sort_by_time(lat_list, lon_list, time_list, wind_speed_list)

This function sorts the output of savedataseperated() by time. This allows for more efficient data processing and allows file caching for times that are represented by the same GOES file.

Parameters:

- lat_list (ndarray, required): The latitude values of the scatterometer data.
- lon_list (ndarray, required): The longitude values of the scatterometer data.
- time_list (ndarray, required): The measurement time values of the scatterometer data.
- wind_speed_list (ndarray, required): The wind speed values of the scatterometer data.

Returns:

- lat_list_sorted (ndarray): The sorted latitude values of the scatterometer data.
- lon_list_sorted (ndarray): The sorted longitude values of the scatterometer data.
- time_list_sorted (ndarray): The sorted measurement time values of the scatterometer data.
- wind_speed_list_sorted (ndarray): The sorted wind speed values of the scatterometer data.

Source code in windscangeo\func.py
def sort_by_time(lat_list, lon_list, time_list, wind_speed_list):
    """
    This function sorts the output of savedataseperated() by time.
    This allows for more efficient data processing and allows file caching for times that are represented by the same GOES file.

    Args:
        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        wind_speed_list (numpy.ndarray): The wind speed values of the scatterometer data.

    Returns:
        lat_list_sorted (numpy.ndarray): The sorted latitude values of the scatterometer data.
        lon_list_sorted (numpy.ndarray): The sorted longitude values of the scatterometer data.
        time_list_sorted (numpy.ndarray): The sorted measurement time values of the scatterometer data.
        wind_speed_list_sorted (numpy.ndarray): The sorted wind speed values of the scatterometer data.

    """
    # Get the indices that would sort the measurement_time array
    sorted_indices = np.argsort(time_list)

    # Reorder the arrays using the sorted indices
    time_list_sorted = time_list[sorted_indices]
    lat_list_sorted = lat_list[sorted_indices]
    lon_list_sorted = lon_list[sorted_indices]
    speed_list_sorted = wind_speed_list[sorted_indices]

    return lat_list_sorted, lon_list_sorted, time_list_sorted, speed_list_sorted

vectorized_solar_angles(lat, lon, time_utc)

This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function. Current implementation does not use solar angles but only image input.

Parameters:

- lat (ndarray, required): The latitude values of the scatterometer data.
- lon (ndarray, required): The longitude values of the scatterometer data.
- time_utc (ndarray, required): The observation times in UTC.

Returns:

- sza (ndarray): The solar zenith angle in degrees.
- saa (ndarray): The solar azimuth angle in degrees.

Source code in windscangeo\func.py
def vectorized_solar_angles(lat, lon, time_utc):

    """
    This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function. Current implementation does not use solar angles but only image input.

    Args:
        lat (numpy.ndarray): The latitude values of the scatterometer data.
        lon (numpy.ndarray): The longitude values of the scatterometer data.
        time_utc (numpy.ndarray): The observation times in UTC.

    Returns:
        sza (numpy.ndarray): The solar zenith angle in degrees.
        saa (numpy.ndarray): The solar azimuth angle in degrees.
    """

    # Convert time to Julian Day
    timestamp = pd.to_datetime(time_utc).tz_localize(None)
    jd = (
        timestamp.astype("datetime64[ns]").astype(np.int64) / 86400000000000 + 2440587.5
    )
    d = jd - 2451545.0  # Days since J2000

    # Mean longitude, mean anomaly, ecliptic longitude
    g = np.deg2rad((357.529 + 0.98560028 * d) % 360)  # Mean anomaly
    q = np.deg2rad((280.459 + 0.98564736 * d) % 360)  # Mean longitude
    L = (q + np.deg2rad(1.915) * np.sin(g) + np.deg2rad(0.020) * np.sin(2 * g)) % (
        2 * np.pi
    )  # Ecliptic long

    # Obliquity of the ecliptic
    e = np.deg2rad(23.439 - 0.00000036 * d)

    # Sun declination
    sin_delta = np.sin(e) * np.sin(L)
    delta = np.arcsin(sin_delta)

    # Equation of time (in minutes)
    E = 229.18 * (
        0.000075
        + 0.001868 * np.cos(g)
        - 0.032077 * np.sin(g)
        - 0.014615 * np.cos(2 * g)
        - 0.040849 * np.sin(2 * g)
    )

    # Convert time to fractional hours (UTC)
    fractional_hour = timestamp.hour + timestamp.minute / 60 + timestamp.second / 3600

    # Solar time correction
    time_offset = E + 4 * lon  # lon in degrees
    tst = fractional_hour * 60 + time_offset  # True Solar Time in minutes
    ha = np.deg2rad((tst / 4 - 180) % 360)  # Hour angle in radians

    # Convert lat/lon to radians
    lat_rad = np.deg2rad(lat)

    # Solar zenith angle
    cos_zenith = np.sin(lat_rad) * np.sin(delta) + np.cos(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    zenith = np.rad2deg(np.arccos(np.clip(cos_zenith, -1, 1)))  # in degrees

    # Solar saa angle
    sin_saa = -np.sin(ha) * np.cos(delta)
    cos_saa = np.cos(lat_rad) * np.sin(delta) - np.sin(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    saa = np.rad2deg(np.arctan2(sin_saa, cos_saa))
    saa = (saa + 360) % 360  # Normalize

    return zenith, saa
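
A small usage sketch with synthetic coordinates and times; the returned angles are in degrees.

```python
import numpy as np
from windscangeo.func import vectorized_solar_angles

lats = np.array([20.5, 21.0])
lons = np.array([-85.0, -84.5])
times = np.array(["2022-06-01T15:00", "2022-06-01T15:05"], dtype="datetime64[ns]")

sza, saa = vectorized_solar_angles(lats, lons, times)
print(sza, saa)  # solar zenith and azimuth angles in degrees
```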

H5pyDataset

Bases: Dataset

A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory. Need to work on Zarr integration for better performance

Parameters:

- h5_file_path (str, required): Path to the HDF5 file.
- transform (callable, default None): A function/transform to apply to the images.
Source code in windscangeo\func_ml.py
class H5pyDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory.
    Need to work on Zarr integration for better performance

    Args:
        h5_file_path (str): Path to the HDF5 file.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, h5_file_path, transform=None):
        self.h5_file_path = h5_file_path
        self.transform = transform
        self.file = None  # Will be initialized per worker
        with h5py.File(self.h5_file_path, 'r') as f:
            self.length = len(f['targets'])

    def _ensure_file(self):
        if self.file is None:
            self.file = h5py.File(self.h5_file_path, 'r')

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self._ensure_file()

        image = self.file['images'][idx]
        target = self.file['targets'][idx]

        image = torch.tensor(image, dtype=torch.float32)
        target = torch.tensor(target, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, target

    def __del__(self):
        if self.file:
            self.file.close()
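
A usage sketch, assuming an HDF5 file with "images" and "targets" datasets produced earlier in the pipeline; the file name and loader settings are placeholders.

```python
import torch
from windscangeo.func_ml import H5pyDataset

dataset = H5pyDataset("training_data.h5")  # hypothetical file with "images" and "targets"
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

images_batch, targets_batch = next(iter(loader))
print(images_batch.shape, targets_batch.shape)
```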

Normalize

Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch

Parameters:

- mean (list or ndarray, required): Mean values for normalization.
- std (list or ndarray, required): Standard deviation values for normalization.
Source code in windscangeo\func_ml.py
class Normalize:

    """
    Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch 

    Args:
        mean (list or np.ndarray): Mean values for normalization.
        std (list or np.ndarray): Standard deviation values for normalization.
    """
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, x):
        return (x - self.mean) / self.std
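
A minimal sketch; the mean and standard deviation values are hypothetical placeholders for per-channel statistics computed on the training set.

```python
import torch
from windscangeo.func_ml import Normalize

normalize = Normalize(mean=[0.3], std=[0.15])   # hypothetical single-channel statistics

x = torch.rand(1, 128, 128)                     # (channels, height, width)
x_norm = normalize(x)
print(x_norm.mean().item(), x_norm.std().item())
```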

conventional_dataset

Bases: Dataset

A PyTorch Dataset for loading data using regular numpy arrays.

Parameters:

- images (list or ndarray, required): List or array of images.
- targets (list or ndarray, required): List or array of targets corresponding to the images.
- transform (callable, default None): A function/transform to apply to the images.
Source code in windscangeo\func_ml.py
class conventional_dataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images.
        targets (list or np.ndarray): List or array of targets corresponding to the images.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images, targets, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        # 2) Target for sample "idx"
        target = torch.tensor(self.targets[idx], dtype=torch.float32)

        return image, target
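
A self-contained sketch with random arrays standing in for GOES images and wind-speed targets; the normalization statistics are placeholders.

```python
import numpy as np
import torch
from windscangeo.func_ml import Normalize, conventional_dataset

images = np.random.rand(256, 1, 128, 128).astype(np.float32)     # stand-in for GOES images
targets = np.random.uniform(0, 25, size=256).astype(np.float32)  # stand-in for wind speeds (m/s)

dataset = conventional_dataset(images, targets, transform=Normalize(mean=[0.3], std=[0.15]))
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

images_batch, targets_batch = next(iter(loader))
print(images_batch.shape, targets_batch.shape)  # torch.Size([32, 1, 128, 128]) torch.Size([32])
```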

conventional_dataset_inference

Bases: Dataset

A PyTorch Dataset for loading data for inference (no labels) using regular numpy arrays.

Parameters:

- images (list or ndarray, required): List or array of images to be used for inference.
- transform (callable, default None): A function/transform to apply to the images.
Source code in windscangeo\func_ml.py
class conventional_dataset_inference(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data for inference (no labels) using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images to be used for inference.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images,transform=None):
        self.images = images
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        return image

error_plot(best_val_outputs, best_val_labels, path_folder=None)

Plot a scatter plot of model outputs vs true labels for the validation dataset.

Parameters:

- best_val_outputs (list or ndarray, required): Model outputs for the validation dataset.
- best_val_labels (list or ndarray, required): True labels for the validation dataset.
- path_folder (str, default None): Path to save the plot. If None, the plot will not be saved.
Source code in windscangeo\func_ml.py
def error_plot(best_val_outputs, best_val_labels, path_folder=None):
    """
    Plot a scatter plot of model outputs vs true labels for the validation dataset.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        path_folder (str, optional): Path to save the plot. If None, the plot will not be saved.
    """

    max_all = max(max(best_val_outputs), max(best_val_labels))

    plt.figure(figsize=(5, 5))
    plt.plot(best_val_labels, best_val_outputs, "o")
    plt.gca().set_aspect("equal", adjustable="box")
    plt.xlim(0, max_all)
    plt.ylim(0, max_all)
    plt.xlabel("True Labels")
    plt.ylabel("Model Output")
    plt.title("Model Output vs True Labels in test dataset")
    plt.xticks(np.arange(0, 30, 5))
    plt.yticks(np.arange(0, 30, 5))
    plt.plot(
        [min(best_val_labels), max(best_val_labels)],
        [min(best_val_labels), max(best_val_labels)],
        "r--",
    )  # y = x reference line
    if path_folder:
        plt.savefig(os.path.join(path_folder, "scatter_plot.png"))

plot_cloud_cover(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images, but larger images can be diluted. Can be adapted to work with larger images.

Parameters:

- lat_inference (ndarray, required): Latitude values for the inference grid.
- lon_inference (ndarray, required): Longitude values for the inference grid.
- images (ndarray, required): GOES image data to be used for cloud cover calculation.
- path_folder (str, required): Path to save the plot.
- buoy_name (list or ndarray, required): Names of the buoys.
- buoy_lat (list or ndarray, required): Latitude values of the buoys.
- buoy_lon (list or ndarray, required): Longitude values of the buoys.
- time_choice (datetime, required): Time of the inference.
Source code in windscangeo\func_ml.py
def plot_cloud_cover(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images, but larger images can be diluted. Can be adapted to work with larger images.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be used for cloud cover calculation.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = np.mean(images, axis=(2,3))
    threshold = 0.11
    cloud_mask = np.where(mean_images > threshold, 1, 0)
    cloud_mask = cloud_mask.reshape(160,340)

    plot_cloud_mask(lat_inference,lon_inference,cloud_mask,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice)

    percentage_cloud = np.sum(cloud_mask)/cloud_mask.size
    print('Cloud coverage : ',percentage_cloud*100,'%')

    return cloud_mask, percentage_cloud

plot_cloud_mask(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud mask on a map with buoy locations. This works well with 30x30 images, but larger images can be diluted. Can be adapted to work with larger images.

Parameters:

- lat_inference (ndarray, required): Latitude values for the inference grid.
- lon_inference (ndarray, required): Longitude values for the inference grid.
- wind_speeds_inference (ndarray, required): Cloud mask data to be plotted.
- path_folder (str, required): Path to save the plot.
- buoy_name (list or ndarray, required): Names of the buoys.
- buoy_lat (list or ndarray, required): Latitude values of the buoys.
- buoy_lon (list or ndarray, required): Longitude values of the buoys.
- time_choice (datetime, required): Time of the inference.
Source code in windscangeo\func_ml.py
def plot_cloud_mask(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted; the function can be adapted to handle larger images.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Cloud mask data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """

    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, wind_speeds_inference,
        shading='auto', cmap='Blues',
        vmin=0, vmax=1
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='0 = Clear, 1 = Cloudy')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)

    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))

    # Plot each buoy in a single color
    for i, buoy_id in enumerate(unique_buoys):
        # Pick a distinct color from tab20
        color = color_map(i)

        # Identify the indices belonging to this buoy
        mask = np.array(buoy_name_flat) == buoy_id

        # Scatter just those points
        ax.scatter(
            np.array(buoy_lon)[mask],
            np.array(buoy_lat)[mask],
            s=100,
            color=color,            # Set the fill color
            edgecolor='black',
            linewidth=1,
            zorder=12,
            label=buoy_id           # Use buoy_id as the legend label
        )

    # Create the legend
    leg = ax.legend(
        title="Buoy Stations",
        loc="upper right",
        bbox_to_anchor=(1.0, 1.0),
        bbox_transform=ax.transAxes
    )

    # Ensure legend is above all other layers
    leg.set_zorder(999)

    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'Cloud Mask at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())


    plt.savefig(f'{path_folder}/plot_cloud_mask.png')

plot_goes_image(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the GOES image data on a map with buoy locations. Made for 128x128 images whose midpoint is at (64, 64); with other image sizes the plotting will probably not work as expected.

Parameters:

Name Type Description Default
lat_inference ndarray

Latitude values for the inference grid.

required
lon_inference ndarray

Longitude values for the inference grid.

required
images ndarray

GOES image data to be plotted.

required
path_folder str

Path to save the plot.

required
buoy_name list or ndarray

Names of the buoys.

required
buoy_lat list or ndarray

Latitude values of the buoys.

required
buoy_lon list or ndarray

Longitude values of the buoys.

required
time_choice datetime

Time of the inference.

required
Source code in windscangeo\func_ml.py
def plot_goes_image(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the GOES image data on a map with buoy locations. Made for 128x128 images whose midpoint is at (64, 64); with other image sizes the plotting will probably not work as expected.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = images[:,:,64,64]
    mean_images = mean_images.ravel()
    mean_images = mean_images.reshape(160,340)


    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())


    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, mean_images,
        shading='auto',
        vmin=0,
        vmax=1,
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='Brightness Temperature (K) - C01')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)


    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))


    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'GOES image at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())

    plt.savefig(f'{path_folder}/plot_goes_image.png')

    mean_images = np.array(mean_images)

    return mean_images

plot_save_loss(best_val_outputs, best_val_labels, train_losses, val_losses, path_folder, saving=False)

Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

Parameters:

Name Type Description Default
best_val_outputs list or ndarray

Model outputs for the validation dataset.

required
best_val_labels list or ndarray

True labels for the validation dataset.

required
train_losses list

List of training losses per epoch.

required
val_losses list

List of validation losses per epoch.

required
path_folder str

Path to save the plot and optionally the outputs and labels.

required
saving bool

If True, save the best validation outputs and labels. Default is False.

False
Source code in windscangeo\func_ml.py
def plot_save_loss(
    best_val_outputs,
    best_val_labels,
    train_losses,
    val_losses,
    path_folder,
    saving=False,
):
    """
    Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        train_losses (list): List of training losses per epoch.
        val_losses (list): List of validation losses per epoch.
        path_folder (str): Path to save the plot and optionally the outputs and labels.
        saving (bool, optional): If True, save the best validation outputs and labels. Default is False.
    """
    # After training, save only the best validation outputs and labels
    if saving:
        np.save(
            os.path.join(path_folder, "best_validation_outputs.npy"), best_val_outputs
        )
        np.save(
            os.path.join(path_folder, "best_validation_labels.npy"), best_val_labels
        )

    num_epochs = len(train_losses)
    # Plotting the training and validation losses
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, num_epochs + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    text_str = f"num_epochs = {num_epochs}, train loss = {train_losses[-1]:.2f}, validation loss = {val_losses[-1]:.2f}"
    plt.text(
        0.05,
        0.05,
        text_str,
        ha="left",
        va="bottom",
        transform=plt.gca().transAxes,  # Ensures the coordinates are relative to the axes (0 to 1 range)
    )
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(path_folder, "loss_plot.png"))
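
A minimal sketch with placeholder loss curves and validation arrays; the results folder is created here so the plot (and, with saving=True, the .npy files) can be written.

import os
import numpy as np
from windscangeo.func_ml import plot_save_loss

path_folder = "./results_folder/example"
os.makedirs(path_folder, exist_ok=True)

# Placeholder training history and validation results
train_losses = list(np.linspace(3.0, 1.2, 40))
val_losses = list(np.linspace(3.2, 1.5, 40))
best_val_labels = np.random.uniform(0, 12, 200)
best_val_outputs = best_val_labels + np.random.normal(0, 1.0, 200)

plot_save_loss(best_val_outputs, best_val_labels, train_losses, val_losses,
               path_folder, saving=True)   # saving=True also writes the .npy files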

plot_wind_speeds(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the wind speeds on a map with buoy locations. Filter nighttime images and add coastlines, gridlines, and buoy locations.

Parameters:

Name Type Description Default
lat_inference ndarray

Latitude values for the inference grid.

required
lon_inference ndarray

Longitude values for the inference grid.

required
wind_speeds_inference ndarray

Wind speed data to be plotted.

required
path_folder str

Path to save the plot.

required
buoy_name list or ndarray

Names of the buoys.

required
buoy_lat list or ndarray

Latitude values of the buoys.

required
buoy_lon list or ndarray

Longitude values of the buoys.

required
time_choice datetime

Time of the inference.

required
Source code in windscangeo\func_ml.py
def plot_wind_speeds(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the wind speeds on a map with buoy locations. Filter nighttime images and add coastlines, gridlines, and buoy locations.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Wind speed data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    time_str = time_choice.strftime('%Y-%m-%d %H:%M:%S')
    date, time = time_str.split(' ')
    lat = lat_inference
    lon = lon_inference
    wind_speeds = wind_speeds_inference


    buoy_names = buoy_name
    buoy_lat = buoy_lat
    buoy_lon = buoy_lon


    # nighttime mask

    lat_flat = lat.flatten()
    lon_flat = lon.flatten()
    time_flat = np.full(len(lat_flat), pd.Timestamp(f'{date} {time}'), dtype='datetime64[ns]')


    sza, saa = vectorized_solar_angles(lat_flat, lon_flat, time_flat)
    saa = np.reshape(saa,lat.shape)
    sza = np.reshape(sza,lat.shape)

    night_time_mask = np.where(sza > 90, 1, np.nan)
    cmap = ListedColormap(['white'])

    ###############

    min_lon, max_lon, min_lat, max_lat = -70, 0, -12, 20

    # add coastlines and gridlines
    fig = plt.figure(figsize=(22, 10), dpi=100)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    levels = np.arange(0, 13, 1)
    line_colour = 'black'
    line_colours = ['black' for i in levels]
    ax.title.set_text(f'Wind Speed prediction from C01 GOES image (m/s) {date} {time} ')
    ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15,zorder = 5)
    ax.contourf(lon, lat, night_time_mask, transform=ccrs.PlateCarree(),cmap=cmap,zorder = 10)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.LAND, edgecolor='black',zorder= 20)
    ax.set_xticks(np.arange(-70, 1, 10), crs=ccrs.PlateCarree())
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0f}°E'))
    ax.set_yticks(np.arange(-15, 21, 5), crs=ccrs.PlateCarree())
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0f}°N'))
    ax.set_extent((-70, 0, -12, 20))
    ax.hlines(0, -70, 0, color='red', linewidth=1.5, linestyle='--', zorder= 10)
    fig.colorbar(ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15), ax=ax, orientation='vertical', aspect=50, label='Wind Speed (m/s)')
    ax.gridlines(color='gray', linestyle='--', alpha=0.5,zorder= 999)

    for buoy in range(len(buoy_names)):
        lon_b, lat_b = buoy_lon[buoy], buoy_lat[buoy]

        # Skip if out of bounds
        if not (min_lon <= lon_b <= max_lon and min_lat <= lat_b <= max_lat):
            continue

        ax.plot(lon_b, lat_b, 'o', color='red', markersize=5, transform=ccrs.PlateCarree(), label=buoy_names[buoy],zorder = 999)

        ax.text(
            lon_b + 0.2, lat_b + 0.2, buoy_names[buoy],
            fontsize=9, color='white',
            transform=ccrs.PlateCarree(),
            bbox=dict(facecolor='black', edgecolor='none', boxstyle='square,pad=0.2'),zorder = 999
        )
    plt.savefig(f'{path_folder}/plot_wind_speeds.png')

rmse_per_range(model_output, target, path_folder)

Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

Parameters:

Name Type Description Default
model_output list or ndarray

Model outputs for the validation dataset.

required
target list or ndarray

True labels for the validation dataset.

required
path_folder str

Path to save the CSV file.

required

Returns: pd.DataFrame: DataFrame containing the RMSE and count for each range.

Source code in windscangeo\func_ml.py
def rmse_per_range(model_output, target,path_folder):
    """
    Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

    Args:
        model_output (list or np.ndarray): Model outputs for the validation dataset.
        target (list or np.ndarray): True labels for the validation dataset.
        path_folder (str): Path to save the CSV file.
    Returns:
        pd.DataFrame: DataFrame containing the RMSE and count for each range.
    """

    max_target = np.max(target)
    bins = np.arange(0, max_target, 1)
    rmse = np.zeros(len(bins))
    count = np.zeros(len(bins))
    results = []

    for i in range(len(bins)-1):
        idx = np.where((target >= bins[i]) & (target <= bins[i+1]))
        rmse[i] = np.sqrt(np.mean((model_output[idx] - target[idx])**2))
        count[i] = len(idx[0])
        print(f"EVAL : Range {bins[i]} m/s - {bins[i+1]} m/s: RMSE = {rmse[i]}, count = {int(count[i])}")
        results.append({'bin_start': bins[i], 'bin_end': bins[i+1], 'rmse': rmse[i], 'count': count[i]})

    df = pd.DataFrame(results)
    df.to_csv(f'{path_folder}/rmse_per_range.csv')
    return df
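
A minimal sketch with synthetic predictions; the output folder is a placeholder and is created first so the CSV can be written.

import os
import numpy as np
from windscangeo.func_ml import rmse_per_range

path_folder = "./results_folder/example"
os.makedirs(path_folder, exist_ok=True)

# Synthetic validation results: true wind speeds up to ~12 m/s, ~1 m/s noise on the predictions
target = np.random.uniform(0, 12, 1000)
model_output = target + np.random.normal(0, 1.0, 1000)

df = rmse_per_range(model_output, target, path_folder)
print(df[["bin_start", "bin_end", "rmse", "count"]])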

vectorized_solar_angles(lat, lon, time_utc)

This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function: the current implementation does not use solar angles, only the image input.

Parameters:

Name Type Description Default
lat ndarray

The latitude values of the scatterometer data.

required
lon ndarray

The longitude values of the scatterometer data.

required
time_utc ndarray

The observation times in UTC.

required

Returns:

Name Type Description
sza ndarray

The solar zenith angle in degrees.

saa ndarray

The solar azimuth angle in degrees.

Source code in windscangeo\func_ml.py
def vectorized_solar_angles(lat, lon, time_utc):

    """
    This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function. Current implementation does not use solar angles but only image input.

    Args:
        lat (numpy.ndarray): The latitude values of the scatterometer data.
        lon (numpy.ndarray): The longitude values of the scatterometer data.
        time_utc (numpy.ndarray): The observation times in UTC.

    Returns:
        sza (numpy.ndarray): The solar zenith angle in degrees.
        saa (numpy.ndarray): The solar azimuth angle in degrees.
    """

    # Convert time to Julian Day
    timestamp = pd.to_datetime(time_utc).tz_localize(None)
    jd = (
        timestamp.astype("datetime64[ns]").astype(np.int64) / 86400000000000 + 2440587.5
    )
    d = jd - 2451545.0  # Days since J2000

    # Mean longitude, mean anomaly, ecliptic longitude
    g = np.deg2rad((357.529 + 0.98560028 * d) % 360)  # Mean anomaly
    q = np.deg2rad((280.459 + 0.98564736 * d) % 360)  # Mean longitude
    L = (q + np.deg2rad(1.915) * np.sin(g) + np.deg2rad(0.020) * np.sin(2 * g)) % (
        2 * np.pi
    )  # Ecliptic long

    # Obliquity of the ecliptic
    e = np.deg2rad(23.439 - 0.00000036 * d)

    # Sun declination
    sin_delta = np.sin(e) * np.sin(L)
    delta = np.arcsin(sin_delta)

    # Equation of time (in minutes)
    E = 229.18 * (
        0.000075
        + 0.001868 * np.cos(g)
        - 0.032077 * np.sin(g)
        - 0.014615 * np.cos(2 * g)
        - 0.040849 * np.sin(2 * g)
    )

    # Convert time to fractional hours (UTC)
    fractional_hour = timestamp.hour + timestamp.minute / 60 + timestamp.second / 3600

    # Solar time correction
    time_offset = E + 4 * lon  # lon in degrees
    tst = fractional_hour * 60 + time_offset  # True Solar Time in minutes
    ha = np.deg2rad((tst / 4 - 180) % 360)  # Hour angle in radians

    # Convert lat/lon to radians
    lat_rad = np.deg2rad(lat)

    # Solar zenith angle
    cos_zenith = np.sin(lat_rad) * np.sin(delta) + np.cos(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    zenith = np.rad2deg(np.arccos(np.clip(cos_zenith, -1, 1)))  # in degrees

    # Solar saa angle
    sin_saa = -np.sin(ha) * np.cos(delta)
    cos_saa = np.cos(lat_rad) * np.sin(delta) - np.sin(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    saa = np.rad2deg(np.arctan2(sin_saa, cos_saa))
    saa = (saa + 360) % 360  # Normalize

    return zenith, saa
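
A minimal sketch computing solar angles for three points at one UTC time; the coordinates and timestamp are illustrative.

import numpy as np
from windscangeo.func_ml import vectorized_solar_angles

lat = np.array([0.0, 10.0, 20.0])
lon = np.array([-45.0, -30.0, -15.0])
# One timestamp per point, since the function expects array-like times
time_utc = np.full(3, np.datetime64("2023-06-01T15:00:00"), dtype="datetime64[ns]")

sza, saa = vectorized_solar_angles(lat, lon, time_utc)
print(np.round(sza, 1))   # solar zenith angles in degrees
print(np.round(saa, 1))   # solar azimuth angles in degrees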

Block

Bases: Module

A basic block for ResNet architecture. This block consists of two convolutional layers with batch normalization and ReLU activation. The first layer applies a 3x3 convolution, and the second layer applies another 3x3 convolution. The block also supports downsampling through an optional identity downsample layer. The expansion factor is set to 1, meaning the output channels are the same as the input channels.

taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py

Source code in windscangeo\Models.py
class Block(nn.Module):
    """
    A basic block for ResNet architecture.
    This block consists of two convolutional layers with batch normalization
    and ReLU activation. The first layer applies a 3x3 convolution, and the
    second layer applies another 3x3 convolution. The block also supports
    downsampling through an optional identity downsample layer. The expansion
    factor is set to 1, meaning the output channels are the same as the input
    channels.

    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    """

    expansion = 1
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()


        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
      identity = x.clone()

      x = self.relu(self.batch_norm2(self.conv1(x)))
      x = self.batch_norm2(self.conv2(x))

      if self.i_downsample is not None:
          identity = self.i_downsample(identity)
      print(x.shape)
      print(identity.shape)
      x += identity
      x = self.relu(x)
      return x

Bottleneck

Bases: Module

taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py A bottleneck block for ResNet architecture. This block consists of three convolutional layers with batch normalization and ReLU activation. The first layer reduces the number of channels, the second layer applies a 3x3 convolution, and the third layer expands the number of channels back to the original size. The block also supports downsampling through an optional identity downsample layer.

Source code in windscangeo\Models.py
class Bottleneck(nn.Module):
    """
    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    A bottleneck block for ResNet architecture.
    This block consists of three convolutional layers with batch normalization
    and ReLU activation. The first layer reduces the number of channels,
    the second layer applies a 3x3 convolution, and the third layer expands
    the number of channels back to the original size. The block also supports
    downsampling through an optional identity downsample layer.
    """
    expansion = 4
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()
        x = self.relu(self.batch_norm1(self.conv1(x)))

        x = self.relu(self.batch_norm2(self.conv2(x)))

        x = self.conv3(x)
        x = self.batch_norm3(x)

        #downsample if needed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        #add identity
        x+=identity
        x=self.relu(x)

        return x

ConventionalCNN

Bases: Module

A simple CNN for image regression tasks. This model consists of a series of convolutional layers followed by fully connected layers. It is designed to process images and output a single regression value (e.g., wind speed).

Source code in windscangeo\Models.py
class ConventionalCNN(nn.Module):
    """
    A simple CNN for image regression tasks.
    This model consists of a series of convolutional layers followed by
    fully connected layers. It is designed to process images and output a
    single regression value (e.g., wind speed).
    """
    def __init__(
        self,
        image_height: int,
        image_width: int,
        features_cnn: list[int],
        kernel_size: int,
        in_channels: int,
        activation_cnn: nn.Module = nn.ReLU(),
        activation_final: nn.Module = nn.Identity(),
        stride: int = 1,
        dropout_rate: float = 0.2,
    ):
        super().__init__()
        self.activation_cnn = activation_cnn
        self.activation_final = activation_final
        self.dropout_rate = dropout_rate

        # ------- Convolutional backbone -------
        self.convs = nn.ModuleList()
        for feature in features_cnn:
            self.convs.extend(
                [
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=feature,
                        kernel_size=kernel_size,
                        padding=1,
                        stride=stride,
                    ),
                    self.activation_cnn,
                    nn.MaxPool2d(kernel_size=2),
                    nn.Dropout(self.dropout_rate),
                ]
            )
            in_channels = feature

        # ------- Classifier / regressor head -------
        self.flattened_size = self._get_flattened_size(image_height, image_width)
        self.fc_cnn = nn.Linear(self.flattened_size, 64)
        self.dropout_cnn = nn.Dropout(self.dropout_rate)
        self.head = nn.Sequential(
            nn.Linear(64, 16),
            self.activation_cnn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(16, 1),
            self.activation_final,
        )

    def _get_flattened_size(self, h, w):
        x = torch.zeros(1, self.convs[0].in_channels, h, w)
        for layer in self.convs:
            x = layer(x)
        return x.numel()

    def forward(self, image):
        x = image
        for layer in self.convs:
            x = layer(x)
        x = x.view(x.size(0), -1)              # flatten
        x = self.activation_cnn(self.fc_cnn(x))
        x = self.dropout_cnn(x)
        out = self.head(x)
        return out
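
A minimal instantiation sketch with hypothetical hyperparameters for single-channel 128x128 inputs.

import torch
from windscangeo.Models import ConventionalCNN

model = ConventionalCNN(
    image_height=128,
    image_width=128,
    features_cnn=[16, 32, 64],   # three conv/pool stages (hypothetical choice)
    kernel_size=3,
    in_channels=1,
)

x = torch.rand(8, 1, 128, 128)   # batch of 8 single-channel images
print(model(x).shape)            # torch.Size([8, 1]) — one regression value per image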

H5pyDataset

Bases: Dataset

A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory. Zarr integration is still needed for better performance.

Parameters:

Name Type Description Default
h5_file_path str

Path to the HDF5 file.

required
transform callable

A function/transform to apply to the images.

None
Source code in windscangeo\func_ml.py
class H5pyDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data from an HDF5 file. This is useful when dealing with large datasets that do not fit into memory.
    Need to work on Zarr integration for better performance

    Args:
        h5_file_path (str): Path to the HDF5 file.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, h5_file_path, transform=None):
        self.h5_file_path = h5_file_path
        self.transform = transform
        self.file = None  # Will be initialized per worker
        with h5py.File(self.h5_file_path, 'r') as f:
            self.length = len(f['targets'])

    def _ensure_file(self):
        if self.file is None:
            self.file = h5py.File(self.h5_file_path, 'r')

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self._ensure_file()

        image = self.file['images'][idx]
        target = self.file['targets'][idx]

        image = torch.tensor(image, dtype=torch.float32)
        target = torch.tensor(target, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, target

    def __del__(self):
        if self.file:
            self.file.close()
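
A minimal sketch, assuming an HDF5 file (the path and normalization statistics are placeholders) that contains "images" and "targets" datasets as the class expects.

import torch
from windscangeo.func_ml import H5pyDataset, Normalize

dataset = H5pyDataset(
    "training_data.h5",                            # placeholder path
    transform=Normalize(mean=[0.25], std=[0.12]),  # placeholder statistics
)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, shuffle=True,
    num_workers=4,   # each worker lazily opens its own file handle via _ensure_file()
)

images, targets = next(iter(loader))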

Img2Seq

Bases: Module

This layer takes a batch of images as input and returns a batch of sequences.

Shape

input: (b, h, w, c); output: (b, s, d)

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py
class Img2Seq(nn.Module):
    """
    This layer takes a batch of images as input and
    returns a batch of sequences

    Shape:
        input: (b, h, w, c)
        output: (b, s, d)

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    def __init__(self, img_size, patch_size, n_channels, d_model):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        n_tokens = nh * nw

        token_dim = patch_size[0] * patch_size[1] * n_channels
        self.linear = nn.Linear(token_dim, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.randn(n_tokens, d_model))

    def __call__(self, batch):
        batch = patchify(batch, self.patch_size)

        b, c, nh, nw, ph, pw = batch.shape

        # Flattening the patches
        batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
        batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

        batch = self.linear(batch)
        cls = self.cls_token.expand([b, -1, -1])
        emb = batch + self.pos_emb

        return torch.cat([cls, emb], axis=1)

Normalize

Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch

Parameters:

Name Type Description Default
mean list or ndarray

Mean values for normalization.

required
std list or ndarray

Standard deviation values for normalization.

required
Source code in windscangeo\func_ml.py
class Normalize:

    """
    Normalize the input tensor by subtracting the mean and dividing by the standard deviation. Done per batch 

    Args:
        mean (list or np.ndarray): Mean values for normalization.
        std (list or np.ndarray): Standard deviation values for normalization.
    """
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, x):
        return (x - self.mean) / self.std
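
A minimal sketch; the mean and standard deviation below are placeholders rather than statistics from a real training set.

import torch
from windscangeo.func_ml import Normalize

normalize = Normalize(mean=[0.25], std=[0.12])   # one value per image channel

image = torch.rand(1, 128, 128)    # (channels, height, width)
normalized = normalize(image)      # mean/std broadcast over the spatial dimensions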

ResNet

Bases: Module

A ResNet model for image classification or regression tasks. This model consists of an initial convolutional layer, followed by a series of residual blocks, and a fully connected layer for classification or regression. The number of residual blocks in each layer is specified by the layer_list parameter. Taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py

Source code in windscangeo\Models.py
class ResNet(nn.Module):

    """
    A ResNet model for image classification or regression tasks.
    This model consists of an initial convolutional layer, followed by a series
    of residual blocks, and a fully connected layer for classification or regression.
    The number of residual blocks in each layer is specified by the `layer_list` parameter.
    taken from https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
    """
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2, padding=1)

        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*ResBlock.expansion, num_classes)

    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        ii_downsample = None
        layers = []

        if stride != 1 or self.in_channels != planes*ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes*ResBlock.expansion)
            )

        layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
        self.in_channels = planes*ResBlock.expansion

        for i in range(blocks-1):
            layers.append(ResBlock(self.in_channels, planes))

        return nn.Sequential(*layers)
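
A sketch of wiring the Bottleneck blocks above into a ResNet-50-style regressor for single-channel GOES patches; the layer counts follow the usual ResNet-50 recipe and the input batch is a placeholder.

import torch
from windscangeo.Models import ResNet, Bottleneck

# [3, 4, 6, 3] Bottleneck blocks per stage is the standard ResNet-50 layout;
# num_classes=1 gives a single regression output (e.g. wind speed).
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1, num_channels=1)

x = torch.rand(4, 1, 128, 128)
print(model(x).shape)   # torch.Size([4, 1])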

ViT

Bases: Module

Vision Transformer (ViT) model for image classification or regression tasks. This model consists of an image-to-sequence layer, a transformer encoder, and a multi-layer perceptron (MLP) head for classification or regression.

Taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py
class ViT(nn.Module):
    """
    Vision Transformer (ViT) model for image classification or regression tasks.
    This model consists of an image-to-sequence layer, a transformer encoder,
    and a multi-layer perceptron (MLP) head for classification or regression.

    Taken from # https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    def __init__(
        self,
        img_size,
        patch_size,
        n_channels,
        d_model,
        nhead,
        dim_feedforward,
        blocks,
        mlp_head_units,
        n_classes,
    ):
        super().__init__()
        """
        Args:
            img_size: Size of the image
            patch_size: Size of the patch
            n_channels: Number of image channels
            d_model: The number of features in the transformer encoder
            nhead: The number of heads in the multiheadattention models
            dim_feedforward: The dimension of the feedforward network model in the encoder
            blocks: The number of sub-encoder-layers in the encoder
            mlp_head_units: The hidden units of mlp_head
            n_classes: The number of output classes
        """
        self.img2seq = Img2Seq(img_size, patch_size, n_channels, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, activation="gelu", batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, blocks
        )
        self.mlp = get_mlp(d_model, mlp_head_units, n_classes)

        self.output = nn.Identity() # For regression

    def forward(self, batch):

        batch = self.img2seq(batch)
        batch = self.transformer_encoder(batch)
        batch = batch[:, 0, :]
        batch = self.mlp(batch)
        output = self.output(batch)
        return output
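
A construction-only sketch with hypothetical hyperparameters, assuming 128x128 single-channel inputs split into 16x16 patches and that mlp_head_units is given as a list of hidden-layer widths.

from windscangeo.Models import ViT

model = ViT(
    img_size=(128, 128),
    patch_size=(16, 16),
    n_channels=1,
    d_model=64,           # token embedding size
    nhead=4,              # attention heads
    dim_feedforward=128,  # encoder feedforward width
    blocks=4,             # encoder layers
    mlp_head_units=[32],  # hidden units of the head (assumed list format)
    n_classes=1,
)
# model(batch) is expected to map a batch of images to a (batch, 1) regression output.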

conventional_dataset

Bases: Dataset

A PyTorch Dataset for loading data using regular numpy arrays.

Parameters:

Name Type Description Default
images list or ndarray

List or array of images.

required
targets list or ndarray

List or array of targets corresponding to the images.

required
transform callable

A function/transform to apply to the images.

None
Source code in windscangeo\func_ml.py
class conventional_dataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images.
        targets (list or np.ndarray): List or array of targets corresponding to the images.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images, targets, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        # 2) Target for sample "idx"
        target = torch.tensor(self.targets[idx], dtype=torch.float32)

        return image, target
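
A minimal sketch with in-memory placeholder arrays wrapped in a DataLoader.

import numpy as np
import torch
from windscangeo.func_ml import conventional_dataset

# Placeholder data: 100 single-channel 128x128 images with wind-speed targets
images = np.random.rand(100, 1, 128, 128).astype(np.float32)
targets = np.random.uniform(0, 12, 100).astype(np.float32)

dataset = conventional_dataset(images, targets)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

batch_images, batch_targets = next(iter(loader))
print(batch_images.shape, batch_targets.shape)   # torch.Size([16, 1, 128, 128]) torch.Size([16])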

conventional_dataset_inference

Bases: Dataset

A PyTorch Dataset for loading data for inference (no labels) using regular numpy arrays.

Parameters:

Name Type Description Default
images list or ndarray

List or array of images to be used for inference.

required
transform callable

A function/transform to apply to the images.

None
Source code in windscangeo\func_ml.py
class conventional_dataset_inference(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading data for inference (no labels) using regular numpy arrays.

    Args:
        images (list or np.ndarray): List or array of images to be used for inference.
        transform (callable, optional): A function/transform to apply to the images.
    """
    def __init__(self, images,transform=None):
        self.images = images
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # 1) Image
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        if self.transform:
            image = self.transform(image)

        return image
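
The inference variant is used the same way but yields images only; the array below is again a placeholder.

import numpy as np
import torch
from windscangeo.func_ml import conventional_dataset_inference

images = np.random.rand(50, 1, 128, 128).astype(np.float32)   # placeholder inference patches
loader = torch.utils.data.DataLoader(
    conventional_dataset_inference(images), batch_size=16, shuffle=False
)
batch = next(iter(loader))   # torch.Size([16, 1, 128, 128]); no targets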

buoy_data_extract(folder_path, polar_data, date)

Extracts buoy data from a specified folder and returns arrays of latitude, longitude, time, wind speed, and buoy names.

Parameters:

Name Type Description Default
folder_path str

Path to the folder containing buoy data files.

required
polar_data Dataset

Polar data containing latitude and longitude information. Used to snap buoy data to the nearest polar grid points.

required
date str

Date for which to extract buoy data, in 'YYYY-MM-DD' format.

required

Returns:

Name Type Description
buoy_lat ndarray

Array of buoy latitudes snapped to the nearest polar grid points.

buoy_lon ndarray

Array of buoy longitudes snapped to the nearest polar grid points

buoy_time ndarray

Array of buoy observation times.

buoy_wind_speed ndarray

Array of buoy wind speeds.

buoy_name ndarray

Array of buoy names.

Source code in windscangeo\func_inference.py
def buoy_data_extract(folder_path, polar_data, date):
    """ Extracts buoy data from a specified folder and returns arrays of latitude, longitude, time, wind speed, and buoy names.

    Args:
        folder_path (str): Path to the folder containing buoy data files.
        polar_data (xarray.Dataset): Polar data containing latitude and longitude information. Used to snap buoy data to the nearest polar grid points.
        date (str): Date for which to extract buoy data, in 'YYYY-MM-DD' format.

    Returns:
        buoy_lat (np.ndarray): Array of buoy latitudes snapped to the nearest polar grid points.
        buoy_lon (np.ndarray): Array of buoy longitudes snapped to the nearest polar grid points
        buoy_time (np.ndarray): Array of buoy observation times.
        buoy_wind_speed (np.ndarray): Array of buoy wind speeds.
        buoy_name (np.ndarray): Array of buoy names.
    """
    buoy_lat = []
    buoy_lon = []
    buoy_wind_speed = []
    buoy_time = []
    buoy_name = []

    for file in os.listdir(folder_path):
        if ".cdf" in file:
            file_path = os.path.join(folder_path, file)
            opened = xr.open_dataset(file_path)
            lat, lon, time, wind_speed, name = form_arrays_buoy(opened, date)
            if np.sum(wind_speed) > 0:
                buoy_lat.extend(lat)
                buoy_lon.extend(lon)
                buoy_wind_speed.extend(wind_speed)
                buoy_time.append(time)
                buoy_name.append(name)

    buoy_lat = np.array(buoy_lat)
    buoy_lat = snap_to_nearest(buoy_lat, polar_data.latitude.values, cutoff=0.8)
    buoy_lon = np.array(buoy_lon)
    buoy_wind_speed = np.array(buoy_wind_speed)
    buoy_time = np.array(buoy_time)

    buoy_lon = np.where(buoy_lon > 180, buoy_lon - 360, buoy_lon)
    buoy_lon = snap_to_nearest(buoy_lon, polar_data.longitude.values, cutoff=0.8)

    return buoy_lat, buoy_lon, buoy_time, buoy_wind_speed, buoy_name
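
A minimal sketch; the folder of .cdf buoy files and the scatterometer NetCDF used for grid snapping are placeholders.

import xarray as xr
from windscangeo.func_inference import buoy_data_extract

polar_data = xr.open_dataset("scatterometer_sample.nc")   # placeholder grid file
buoy_lat, buoy_lon, buoy_time, buoy_wind_speed, buoy_name = buoy_data_extract(
    folder_path="./buoy_data",   # placeholder folder of .cdf buoy files
    polar_data=polar_data,
    date="2023-06-01",
)
print(len(buoy_name), "buoys with non-zero wind speeds")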

calculate_degrees(file_id)

This function calculates the latitude and longitude of the GOES ABI fixed grid projection. This function comes from NOAA/NESDIS/STAR. (2025). Latitude and longitude remapping of GOES-R ABI imagery using Python. Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

Parameters:

Name Type Description Default
file_id Dataset

The xarray dataset containing the GOES ABI fixed grid projection variables.

required

Returns:

Name Type Description
abi_lat ndarray

The latitude of the GOES ABI fixed grid projection.

abi_lon ndarray

The longitude of the GOES ABI fixed grid projection.

Source code in windscangeo\func.py
def calculate_degrees(file_id):
    """This function calculates the latitude and longitude of the GOES ABI fixed grid projection. 
    This function comes from NOAA/NESDIS/STAR. (2025). Latitude and longitude remapping of GOES-R ABI imagery using Python . Atmospheric Composition Science Team. Retrieved from https://www.star.nesdis.noaa.gov/atmospheric-composition-training/python_abi_lat_lon.php

    Args:
        file_id (xarray.Dataset): The xarray dataset containing the GOES ABI fixed grid projection variables.

    Returns:
        abi_lat (numpy.ndarray): The latitude of the GOES ABI fixed grid projection.
        abi_lon (numpy.ndarray): The longitude of the GOES ABI fixed grid projection.


    """

    # Read in GOES ABI fixed grid projection variables and constants
    x_coordinate_1d = file_id.variables["x"][:]  # E/W scanning angle in radians
    y_coordinate_1d = file_id.variables["y"][:]  # N/S elevation angle in radians
    projection_info = file_id.goes_imager_projection
    lon_origin = projection_info.longitude_of_projection_origin
    H = projection_info.perspective_point_height + projection_info.semi_major_axis
    r_eq = projection_info.semi_major_axis
    r_pol = projection_info.semi_minor_axis

    # Create 2D coordinate matrices from 1D coordinate vectors
    x_coordinate_2d, y_coordinate_2d = np.meshgrid(x_coordinate_1d, y_coordinate_1d)

    # Equations to calculate latitude and longitude
    lambda_0 = (lon_origin * np.pi) / 180.0
    a_var = np.power(np.sin(x_coordinate_2d), 2.0) + (
        np.power(np.cos(x_coordinate_2d), 2.0)
        * (
            np.power(np.cos(y_coordinate_2d), 2.0)
            + (
                ((r_eq * r_eq) / (r_pol * r_pol))
                * np.power(np.sin(y_coordinate_2d), 2.0)
            )
        )
    )
    b_var = -2.0 * H * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    c_var = (H**2.0) - (r_eq**2.0)
    r_s = (-1.0 * b_var - np.sqrt((b_var**2) - (4.0 * a_var * c_var))) / (2.0 * a_var)
    s_x = r_s * np.cos(x_coordinate_2d) * np.cos(y_coordinate_2d)
    s_y = -r_s * np.sin(x_coordinate_2d)
    s_z = r_s * np.cos(x_coordinate_2d) * np.sin(y_coordinate_2d)

    # Ignore numpy errors for sqrt of negative number; occurs for GOES-16 ABI CONUS sector data
    np.seterr(all="ignore")

    abi_lat = (180.0 / np.pi) * (
        np.arctan(
            ((r_eq * r_eq) / (r_pol * r_pol))
            * ((s_z / np.sqrt(((H - s_x) * (H - s_x)) + (s_y * s_y))))
        )
    )
    abi_lon = (lambda_0 - np.arctan(s_y / (H - s_x))) * (180.0 / np.pi)

    print("INFO : Latitude and longitude calculated")
    return abi_lat, abi_lon
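
A minimal sketch, assuming a locally downloaded full-disk ABI file (the filename is a placeholder); any dataset exposing the standard goes_imager_projection variable should work.

import xarray as xr
from windscangeo.func import calculate_degrees

ds = xr.open_dataset("OR_ABI-L2-CMIPF_sample.nc")   # placeholder ABI L2 CMIP file
abi_lat, abi_lon = calculate_degrees(ds)
print(abi_lat.shape, abi_lon.shape)   # 2D lat/lon grids matching the ABI fixed grid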

create_folder(experiment_name)

Create a folder for saving results based on the experiment name.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment to create a folder for.

required

Returns:

Name Type Description
str

Path to the created folder.

Source code in windscangeo\func.py
def create_folder(experiment_name):
    """
    Create a folder for saving results based on the experiment name.

    Args:
        experiment_name (str): Name of the experiment to create a folder for.
        If the folder already exists, it will not be created again.

    Returns:
        str: Path to the created folder.
    """

    path_folder = f"./results_folder/model_day_{experiment_name}"

    if not os.path.exists(path_folder):
        os.makedirs(path_folder)
        print(f"Folder created at {path_folder}")

    return path_folder

early_stopping(valid_losses, patience_epochs, patience_loss)

Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

Parameters:

Name Type Description Default
valid_losses list

List of validation losses recorded during training.

required
patience_epochs int

Number of epochs to wait before stopping if no improvement.

required
patience_loss float

Minimum change in validation loss to consider as an improvement.

required

Returns:

Name Type Description
bool

True if training should stop, False otherwise.

Source code in windscangeo\impl.py
def early_stopping(valid_losses, patience_epochs, patience_loss):  # From @Jing
    """
    Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

    Args:
        valid_losses (list): List of validation losses recorded during training.
        patience_epochs (int): Number of epochs to wait before stopping if no improvement.
        patience_loss (float): Minimum change in validation loss to consider as an improvement.

    Returns:
        bool: True if training should stop, False otherwise.
    """
    if len(valid_losses) < patience_epochs:
        return False
    recent_losses = valid_losses[-patience_epochs:]

    if all(x >= recent_losses[0] for x in recent_losses):
        return True

    if max(recent_losses) - min(recent_losses) < patience_loss:
        return True
    return False
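
A sketch of how the check might sit inside a training loop; the shrinking loss values stand in for real validation losses.

from windscangeo.impl import early_stopping

valid_losses = []
for epoch in range(100):
    # ... training and validation would go here; a fake shrinking loss stands in for it
    valid_losses.append(1.0 / (epoch + 1))

    # Stop when the last 10 losses never improve on the first of the window,
    # or when they differ by less than 0.01 overall.
    if early_stopping(valid_losses, patience_epochs=10, patience_loss=0.01):
        print(f"Early stopping at epoch {epoch}")
        break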

error_plot(best_val_outputs, best_val_labels, path_folder=None)

Plot a scatter plot of model outputs vs true labels for the validation dataset.

Parameters:

Name Type Description Default
best_val_outputs list or ndarray

Model outputs for the validation dataset.

required
best_val_labels list or ndarray

True labels for the validation dataset.

required
path_folder str

Path to save the plot. If None, the plot will not be saved.

None
Source code in windscangeo\func_ml.py
def error_plot(best_val_outputs, best_val_labels, path_folder=None):
    """
    Plot a scatter plot of model outputs vs true labels for the validation dataset.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        path_folder (str, optional): Path to save the plot. If None, the plot will not be saved.
    """

    max_all = max(max(best_val_outputs), max(best_val_labels))

    plt.figure(figsize=(5, 5))
    plt.plot(best_val_labels, best_val_outputs, "o")
    plt.gca().set_aspect("equal", adjustable="box")
    plt.xlim(0, max_all)
    plt.ylim(0, max_all)
    plt.xlabel("True Labels")
    plt.ylabel("Model Output")
    plt.title("Model Output vs True Labels in test dataset")
    plt.xticks(np.arange(0, 30, 5))
    plt.yticks(np.arange(0, 30, 5))
    plt.plot(
        [min(best_val_labels), max(best_val_labels)],
        [min(best_val_labels), max(best_val_labels)],
        "r--",
    )  # y = x reference line
    if path_folder:
        plt.savefig(os.path.join(path_folder, "scatter_plot.png"))

extract_goes(observation_times, observation_lats, observation_lons, scatterometer_data_path, goes_aws_url_folder, goes_channel='C01', goes_image_size=128, verbose=True)

This function extracts GOES images for the given observation times, latitudes, and longitudes. It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

Parameters:

Name Type Description Default
observation_times ndarray

The times of observation of the scatterometer data.

required
observation_lats ndarray

The latitudes of the scatterometer data.

required
observation_lons ndarray

The longitudes of the scatterometer data.

required
scatterometer_data_path str

The path to the scatterometer data directory.

required
goes_aws_url_folder str

The folder in the AWS S3 bucket where the GOES data is stored.

required
goes_channel str

The channel of interest. Default is "C01".

'C01'
goes_image_size int

The size of the output images. Default is 128.

128
verbose bool

If True, prints progress information.

True

Returns:

Name Type Description
images ndarray

A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

Source code in windscangeo\func.py
def extract_goes(
    observation_times,
    observation_lats,
    observation_lons,
    scatterometer_data_path,
    goes_aws_url_folder,
    goes_channel="C01",
    goes_image_size=128,
    verbose=True,
):
    """
    This function extracts GOES images for the given observation times, latitudes, and longitudes.
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create images of the specified size.

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data. 
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        scatterometer_data_path (str): The path to the scatterometer data directory.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored.
        goes_channel (str): The channel of interest. Default is "C01".
        goes_image_size (int): The size of the output images. Default is 128.
        verbose (bool): If True, prints progress information.

    Returns:
        images (numpy.ndarray): A 4D numpy array of shape (num_observations, num_channels, goes_image_size, goes_image_size) containing the extracted GOES images.

    """

    for file in os.listdir(scatterometer_data_path):
        if file.endswith(".nc"):
            polar = xr.open_dataset(
                os.path.join(scatterometer_data_path, file),
                engine="h5netcdf",
                drop_variables=["DQF"],
            )
            break

        else:
            print('WARNING : No .nc file found in the scatterometer data path, please check the path')

    template_scatter = polar.isel(time=0)
    lat_grd, lon_grd = (
        template_scatter["latitude"].values,
        template_scatter["longitude"].values,
    )

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**1024)

    values, counts = np.unique(observation_times, return_counts=True)

    all_urls = []  # getting unique URLS
    for value in values:
        urls = get_goes_url(value, goes_aws_url_folder,goes_channel)
        all_urls.append(urls)

    values_url, indices_url, counts_url = np.unique(
        all_urls, return_index=True, return_counts=True, axis=0
    )
    # Sort indices to "unsort" the URLs
    sorted_indices = sorted(range(len(indices_url)), key=lambda k: indices_url[k])
    values_url = [all_urls[indices_url[i]] for i in sorted_indices]

    # Reorder counts_url using the same sorted indices
    counts_url = [counts_url[i] for i in sorted_indices]

    compressed_urls = values_url
    compressed_counts = []
    start_idx = 0

    for size in counts_url:
        group_sum = counts[start_idx : start_idx + size].sum()
        compressed_counts.append(group_sum)
        start_idx += size

    width = goes_image_size
    height = goes_image_size

    images = np.zeros([len(observation_times), 1 , width, height], dtype=np.float32)

    total_idx = 0
    for unique_idx, unique_urls in tqdm(
        enumerate(compressed_urls),
        desc="INFO : Retrieving and processing GOES data",
        total=len(compressed_urls),
        disable=not verbose,
    ):


        for CH_idx, url_CH in enumerate(unique_urls):

            if url_CH == 0:
                images[total_idx, CH_idx] = np.zeros([width, height])
                continue

            with fs.open(url_CH, mode="rb") as f:

                ds = xr.open_dataset(
                    f, engine="h5netcdf", drop_variables=["DQF"]
                )  # this is the bottleneck

                parallel_index = index_parallel(
                    ds,
                    template_scatter,
                )
                for i in range(compressed_counts[unique_idx]):
                    images[total_idx + i, CH_idx] = get_image(
                        ds=ds,
                        parallel_index=parallel_index,
                        lat_grd=lat_grd,
                        lon_grd=lon_grd,
                        lat_search=observation_lats[total_idx + i],
                        lon_search=observation_lons[total_idx + i],
                        goes_image_size=goes_image_size,
                    )

        total_idx += compressed_counts[unique_idx]

    if verbose:
        print(
            f"INFO : Extracted {len(observation_times)} images from {len(compressed_urls)} GOES files."
        )
    return images
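
A minimal sketch; the observation arrays and scatterometer folder are placeholders, and the call needs anonymous S3 access to the NOAA GOES archive.

import numpy as np
from windscangeo.func import extract_goes

# Placeholder matched observations (normally taken from the scatterometer orbits)
observation_times = np.array(["2023-06-01T15:00"], dtype="datetime64[ns]")
observation_lats = np.array([10.0])
observation_lons = np.array([-40.0])

images = extract_goes(
    observation_times,
    observation_lats,
    observation_lons,
    scatterometer_data_path="./scatterometer_data",   # placeholder folder with matched .nc files
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
    goes_channel="C01",
    goes_image_size=128,
)
print(images.shape)   # (1, 1, 128, 128)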

extract_goes_inference(date_time, parallel_index, channels='C01', goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF')

This function extracts GOES images for a given date_time and parallel_index (a whole GOES slice, used for inference; this differs from the training images, which are matched to scatterometer orbits). It retrieves the GOES data from the specified AWS S3 bucket and processes it into images of the specified size (128x128).

Parameters:

    date_time (datetime64): The time of the GOES data. Required.
    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    channels (str or list): The channel(s) of interest. Default: 'C01'.
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Default: 'noaa-goes16/ABI-L2-CMIPF'.

Returns:

    images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).

Source code in windscangeo\func.py, lines 937-1008
def extract_goes_inference(date_time, parallel_index,channels="C01",goes_aws_url_folder= 'noaa-goes16/ABI-L2-CMIPF'):
    """
    This function extracts GOES images for a given date_time and parallel_index. (whole GOES slice, used for inference which differs from images used in training that have a matched orbit with scatterometers.)
    It retrieves the GOES data from the specified AWS S3 bucket and processes it to create
    images of the specified size (128x128).

    Args:
        date_time (numpy.datetime64): The time of the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        channels (str or list): The channel(s) of interest. Default is "C01".
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES data is stored. Default is 'noaa-goes16/ABI-L2-CMIPF'.

    Returns:
        images (list): A list of numpy arrays containing the extracted GOES images of shape (128, 128).
    """

    # ignore divide by zero errors which occur when the GOES data can't form a 128x128 image
    np.seterr(invalid='ignore', divide='ignore')

    fs = fsspec.filesystem("s3", anon=True, default_block_size=512 * 1024**1024)
    urls = get_goes_url(date_time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel= channels)
    with fs.open(urls[0], mode="rb") as f:
        print("INFO : Reading file:", urls[0])
        goes_image = xr.open_dataset(f)
        goes_image = goes_image.rename({"x": "x_index", "y": "y_index"})

        # Assign the index coordinates (if not already done)
        goes_image = goes_image.assign_coords(
            x_index=np.arange(goes_image.sizes["x_index"]),
            y_index=np.arange(goes_image.sizes["y_index"]),
        )

        images = []
        goes_image.load()
        print("INFO : Extracting images")
        for i in range(parallel_index.shape[0]):
            for j in range(parallel_index.shape[1]):
                try:
                    x_mean = parallel_index[i][j][1].mean().astype(int)
                    x_min = x_mean - 63
                    x_max = x_mean + 63
                    y_mean = parallel_index[i][j][0].mean().astype(int)
                    y_min = y_mean - 63
                    y_max = y_mean + 63
                    image = goes_image.CMI.sel(
                        x_index=slice(x_min, x_max), y_index=slice(y_min, y_max)
                    )

                    target_size = (128, 128)

                    padded_image = np.pad(
                        image,
                        (
                            (
                                (target_size[0] - image.shape[0]) // 2,
                                (target_size[0] - image.shape[0] + 1) // 2,
                            ),
                            (
                                (target_size[1] - image.shape[1]) // 2,
                                (target_size[1] - image.shape[1] + 1) // 2,
                            ),
                        ),
                        constant_values=0,
                    )

                except:
                    images.append(np.zeros((128, 128)))

                    continue
                images.append(padded_image)

        return images
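
A minimal usage sketch (not part of the package documentation; the import path and the precomputed parallel_index are assumptions, see index_parallel below):

import numpy as np
from windscangeo.func import extract_goes_inference

# parallel_index: precomputed GOES-pixel indices, produced by index_parallel (see below)
images = extract_goes_inference(
    np.datetime64("2020-07-01T15:00:00"),   # arbitrary example time
    parallel_index,
    channels="C01",
)
print(len(images), images[0].shape)          # one (128, 128) array per scatterometer grid cell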

extract_goes_production(time_choice, polar_data, parallel_index, channels, goes_aws_url_folder)

Extracts GOES data for a specific time from the polar data and returns the images along with valid latitudes, longitudes, and times.

Parameters:

    time_choice (str): The time for which to extract GOES data, in 'YYYY-MM-DD HH:MM:SS' format. Required.
    polar_data (Dataset): Polar data containing latitude and longitude information. Used to create a grid of valid latitudes and longitudes. Required.
    parallel_index (int): Index for parallel processing, used to identify the specific GOES data to extract, generated by the index_parallel function. Required.
    channels (list): List of GOES channels to extract. Required.
    goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF". Required.

Returns:

    images (ndarray): Array of extracted GOES images for the specified time.
    valid_lats (ndarray): Array of valid latitudes corresponding to the GOES images.
    valid_lons (ndarray): Array of valid longitudes corresponding to the GOES images.
    valid_times (ndarray): Array of valid times corresponding to the GOES images.

Source code in windscangeo\func_inference.py, lines 74-113
def extract_goes_production(time_choice, polar_data, parallel_index,channels,goes_aws_url_folder):

    """ 
    Extracts GOES data for a specific time from the polar data and returns the images along with valid latitudes, longitudes, and times.

    Args:
        time_choice (str): The time for which to extract GOES data, in 'YYYY-MM-DD HH:MM:SS' format.
        polar_data (xarray.Dataset): Polar data containing latitude and longitude information. Used to create a grid of valid latitudes and longitudes.
        parallel_index (int): Index for parallel processing, used to identify the specific GOES data to extract, generated by the `index_parallel` function.
        channels (list): List of GOES channels to extract.
        goes_aws_url_folder (str): AWS URL folder for the GOES data, default is "noaa-goes16/ABI-L2-CMIPF".

    Returns:
        images (np.ndarray): Array of extracted GOES images for the specified time.
        valid_lats (np.ndarray): Array of valid latitudes corresponding to the GOES images
        valid_lons (np.ndarray): Array of valid longitudes corresponding to the GOES images
        valid_times (np.ndarray): Array of valid times corresponding to the GOES images.

    """
    time_formated = (
        np.datetime64(time_choice).astype("datetime64[ns]").astype("float64")
    )

    longrid, latgrid = np.meshgrid(polar_data["longitude"], polar_data["latitude"])
    lon_array = longrid.flatten()
    lat_array = latgrid.flatten()
    time_array = np.full_like(lon_array, time_formated)

    valid_lons = lon_array
    valid_lats = lat_array
    valid_times = time_array

    print('INFO : Extracting GOES data')
    images = extract_goes_inference(np.datetime64(time_choice), parallel_index,channels,goes_aws_url_folder)


    images = np.expand_dims(images, axis=1)


    return images, valid_lats, valid_lons, valid_times
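
A minimal usage sketch, assuming a local scatterometer file and a precomputed parallel_index (both hypothetical); the import path is an assumption:

import xarray as xr
from windscangeo.func_inference import extract_goes_production

polar_data = xr.open_dataset("ascat_sample.nc")        # hypothetical scatterometer file with latitude/longitude coords
images, valid_lats, valid_lons, valid_times = extract_goes_production(
    "2020-07-01 15:00:00",
    polar_data,
    parallel_index,                                     # from index_parallel (see below)
    channels="C01",
    goes_aws_url_folder="noaa-goes16/ABI-L2-CMIPF",
)
print(images.shape)                                     # (n_grid_cells, 1, 128, 128)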

extract_scatter(polar_data, date, lat_range, lon_range, verbose=True, main_variable='wind_speed')

This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range and longitude range. The function then saves the data into 4 numpy files: time of observation, latitude, longitude and main variable.

Parameters:

    polar_data (Dataset): The scatterometer dataset (ASCAT, HYSCAT etc). Required.
    date (datetime64): The time of the scatterometer data. Required.
    lat_range (tuple): The latitude range of the scatterometer data. Required.
    lon_range (tuple): The longitude range of the scatterometer data. Required.
    verbose (bool): If True, the function will print the progress of the extraction. Default: True.
    main_variable (str): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification etc. Default: 'wind_speed'.

Returns:

    observation_times (ndarray): The time of observation of the scatterometer data.
    observation_lats (ndarray): The latitude of the scatterometer data.
    observation_lons (ndarray): The longitude of the scatterometer data.
    observation_main_parameter (ndarray): Main parameter extracted (wind_speed).

Source code in windscangeo\func.py, lines 475-523
def extract_scatter(
    polar_data,
    date,
    lat_range,
    lon_range,
    verbose=True,
    main_variable="wind_speed",
):
    """
    This function extracts the scatterometer data from the polar_data dataset for the given time range, latitude range and longitude range.
    The function then saves the data into 4 numpy files : time of observation, latitude, longitude and main variable.

    Args:
        polar_data (xarray.Dataset): The scatterometer dataset (ASCAT, HYSCAT etc).
        date (numpy.datetime64): The time of the scatterometer data.
        lat_range (tuple): The latitude range of the scatterometer data.
        lon_range (tuple): The longitude range of the scatterometer data.
        verbose (bool): If True, the function will print the progress of the extraction.
        main_variable (str): The main variable to be extracted from the scatterometer data. This can be wind speed, wind direction, classification etc.

    Returns:
        observation_times (numpy.ndarray): The time of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitude of the scatterometer data.
        observation_lons (numpy.ndarray): The longitude of the scatterometer data.
        observation_main_parameter (numpy.ndarray): main parameter extracted (wind_speed).

    """

    polar = polar_data.sel(
        time=slice(date, date),
        latitude=slice(lat_range[0], lat_range[1]),
        longitude=slice(lon_range[0], lon_range[1]),
    )

    seperated_scatter = savedataseperated(polar, polar[main_variable],verbose=verbose)

    observation_times = seperated_scatter[2]
    observation_lats = seperated_scatter[0]
    observation_lons = seperated_scatter[1]
    observation_wind_speeds = seperated_scatter[3]



    return (
        observation_times,
        observation_lats,
        observation_lons,
        observation_wind_speeds,
    )
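
A minimal usage sketch with a hypothetical ASCAT file; the import path is an assumption:

import numpy as np
import xarray as xr
from windscangeo.func import extract_scatter

polar_data = xr.open_dataset("ascat_sample.nc")         # hypothetical scatterometer dataset
times, lats, lons, wind_speeds = extract_scatter(
    polar_data,
    date=np.datetime64("2020-07-01"),
    lat_range=(10, 30),                                  # degrees north
    lon_range=(-90, -60),                                # degrees east
    main_variable="wind_speed",
)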

extract_scatter_multisat(scatterometer_data_path, date, lat_range, lon_range, verbose=True)

Extracts scatterometer data from multiple files (.nc) in a specified directory.

Parameters:

    scatterometer_data_path (str): Path to the directory containing scatterometer data files. Required.
    date (datetime): Date for which to extract data. Required.
    lat_range (tuple): Latitude range (min, max) for filtering data. Required.
    lon_range (tuple): Longitude range (min, max) for filtering data. Required.
    verbose (bool): If True, prints progress information. Default: True.

Returns:

    tuple: A tuple containing:
        - list of datetime: observation times
        - list of float: latitudes
        - list of float: longitudes
        - list of float: wind speeds

Source code in windscangeo\func_inference.py, lines 15-72
def extract_scatter_multisat(
    scatterometer_data_path, date, lat_range, lon_range,verbose=True
):

    """
    Extracts scatterometer data from multiple files (`.nc`) in a specified directory.

    Args:
        scatterometer_data_path (str): Path to the directory containing scatterometer data files.
        date (datetime): Date for which to extract data.
        lat_range (tuple): Latitude range (min, max) for filtering data.
        lon_range (tuple): Longitude range (min, max) for filtering data.
        verbose (bool): If True, prints progress information.

    Returns:
        tuple: A tuple containing:
            - list of datetime: observation times
            - list of float: latitudes
            - list of float: longitudes
            - list of float: wind speeds
    """

    observation_times = []
    observation_lats = []
    observation_lons = []
    observation_wind_speeds = []

    if verbose:
        print("INFO : Extracting scatterometer data from folder : ", scatterometer_data_path)
        print("___")

    for file in os.listdir(scatterometer_data_path):
        if ".nc" in file:
            # Open the file
            file_path = scatterometer_data_path + file
            polar_data = xr.open_dataset(file_path)
            (
                observation_times_local,
                observation_lats_local,
                observation_lons_local,
                observation_wind_speeds_local,
            ) = extract_scatter(
                polar_data, date, lat_range, lon_range, verbose=verbose
            )
            observation_times.extend(observation_times_local)
            observation_lats.extend(observation_lats_local)
            observation_lons.extend(observation_lons_local)
            observation_wind_speeds.extend(observation_wind_speeds_local)

    if verbose : 
        print("___")
        print(f"INFO : Total number of scatterometer data points: {len(observation_times)}")
    return (
        observation_times,
        observation_lats,
        observation_lons,
        observation_wind_speeds,
    )
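
A minimal usage sketch with a hypothetical folder. Note that the source above concatenates the path directly with each file name, so a trailing separator is needed:

import numpy as np
from windscangeo.func_inference import extract_scatter_multisat

times, lats, lons, wind_speeds = extract_scatter_multisat(
    "./scatterometer_data/",             # trailing slash matters: path and file name are joined by simple concatenation
    date=np.datetime64("2020-07-01"),
    lat_range=(10, 30),
    lon_range=(-90, -60),
)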

fill_nans(images)

This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

Parameters:

    images (ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images. Required.

Returns:

    images (ndarray): A 4D numpy array with NaN values replaced by zeros.

Source code in windscangeo\func.py, lines 808-820
def fill_nans(images):
    """
    This function fills NaN values in the images with zeros. (This is simply np.nan_to_num)

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.

    Returns:
        images (numpy.ndarray): A 4D numpy array with NaN values replaced by zeros.
    """
    images = np.nan_to_num(images, nan=0.0)
    print("INFO : Filled nans")
    return images

filter_invalid(images, numerical_data, min_nonzero_pixels=50)

This function filters out invalid images and corresponding numerical data based on two criteria: 1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty). 2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

Parameters:

    images (ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images. Required.
    numerical_data (dict): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images. Required.
    min_nonzero_pixels (int): The minimum number of non-zero pixels required for an image to be considered valid. Default: 50.

Returns:

    filtered_images (ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
    filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

Source code in windscangeo\func.py, lines 755-805
def filter_invalid(
    images,
    numerical_data,
    min_nonzero_pixels=50,
):

    """
    This function filters out invalid images and corresponding numerical data based on two criteria:
    1) The sum of pixel values in the image is not zero (i.e., the image is not completely empty).
    2) The number of non-zero pixels in the image is greater than or equal to a specified minimum threshold (default is 50).

    Args:
        images (numpy.ndarray): A 4D numpy array of shape (num_images, num_channels, height, width) containing the GOES images.
        numerical_data (dict): A dictionary containing numerical data associated with the images. The keys should match the dimensions of the images.
        min_nonzero_pixels (int): The minimum number of non-zero pixels required for an image to be considered valid. Default is 50.

    Returns:
        filtered_images (numpy.ndarray): A 4D numpy array of shape (num_valid_images, num_channels, height, width) containing the filtered GOES images.
        filtered_numerical_data (dict): A dictionary containing the numerical data associated with the valid images.

    """
    # Sums of pixel values in each image
    sums_images = [np.nansum(x) for x in images]

    # Counts of non-zero pixels in each image
    nonzero_counts = [np.count_nonzero(x) for x in images]

    # Build a "mask_invalid" array of indices that fail any criterion:
    # 1) sum == 0 (completely empty)
    # 2) nonzero pixel count < min_nonzero_pixels (not enough data)

    mask_valid = np.where(
        (np.array(sums_images) != 0) & (np.array(nonzero_counts) >= min_nonzero_pixels)
    )[0]

    # Delete the invalid entries from each array
    filtered_numerical_data = {
        key: value[mask_valid] for key, value in numerical_data.items()
    }
    filtered_images = images[mask_valid]
    n_removed_images = len(images) - len(filtered_images)

    print(
        "INFO : Filtered invalid images. Removed {} entries.".format(
            n_removed_images
        )
    )
    return (
        filtered_images,
        filtered_numerical_data,
    )
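
A small self-contained sketch with synthetic data (the values are arbitrary; only the shapes matter):

import numpy as np
from windscangeo.func import filter_invalid

images = np.random.rand(100, 1, 128, 128).astype(np.float32)
images[0] = 0.0                                          # a completely empty image, which will be removed
numerical_data = {
    "observation_lats": np.zeros(100),
    "observation_lons": np.zeros(100),
    "observation_times": np.zeros(100),
    "wind_speeds": np.zeros(100),
}
filtered_images, filtered_numerical_data = filter_invalid(images, numerical_data, min_nonzero_pixels=50)
print(filtered_images.shape)                             # (99, 1, 128, 128)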

filter_nighttime(observation_times, observation_lats, observation_lons, observation_wind_speeds, min_hour=10, max_hour=19, verbose=True)

This function filters the scatterometer data to only include observations that were made during daylight hours. The function checks the hour of each observation time and only keeps those that fall within the specified range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

Parameters:

    observation_times (ndarray): The times of observation of the scatterometer data. Required.
    observation_lats (ndarray): The latitudes of the scatterometer data. Required.
    observation_lons (ndarray): The longitudes of the scatterometer data. Required.
    observation_wind_speeds (ndarray): The wind speeds of the scatterometer data. Required.
    min_hour (int): The minimum hour of the day to include. Default: 10.
    max_hour (int): The maximum hour of the day to include. Default: 19.
    verbose (bool): If True, prints the number of valid scatterometer data points at daylight. Default: True.

Returns:

    valid_times (list): A list of valid observation times that fall within the specified hour range.
    valid_lats (list): A list of valid latitudes corresponding to the valid observation times.
    valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
    valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

Source code in windscangeo\func.py, lines 823-871
def filter_nighttime(
    observation_times,
    observation_lats,
    observation_lons,
    observation_wind_speeds,
    min_hour=10,
    max_hour=19,
    verbose=True,
):
    """
    This function filters the scatterometer data to only include observations that were made during daylight hours.
    The function checks the hour of each observation time and only keeps those that fall within the specified
    range (default is 10 to 19, which corresponds to 10 AM to 7 PM UTC).

    Args:
        observation_times (numpy.ndarray): The times of observation of the scatterometer data.
        observation_lats (numpy.ndarray): The latitudes of the scatterometer data.
        observation_lons (numpy.ndarray): The longitudes of the scatterometer data.
        observation_wind_speeds (numpy.ndarray): The wind speeds of the scatterometer data.
        min_hour (int): The minimum hour of the day to include (default is 10).
        max_hour (int): The maximum hour of the day to include (default is 19).
        verbose (bool): If True, prints the number of valid scatterometer data points at daylight.

    Returns:
        valid_times (list): A list of valid observation times that fall within the specified hour range.
        valid_lats (list): A list of valid latitudes corresponding to the valid observation times
        valid_lons (list): A list of valid longitudes corresponding to the valid observation times.
        valid_wind_speeds (list): A list of valid wind speeds corresponding to the valid observation times.

    """

    valid_times = []
    valid_lats = []
    valid_lons = []
    valid_wind_speeds = []

    for idx in range(len(observation_times)):
        only_hour = int(
            observation_times[idx].astype("datetime64[ns]").astype("str")[11:13]
        )
        if min_hour <= only_hour <= max_hour:
            valid_times.append(observation_times[idx])
            valid_lats.append(observation_lats[idx])
            valid_lons.append(observation_lons[idx])
            valid_wind_speeds.append(observation_wind_speeds[idx])

    if verbose:
        print(f"INFO : Total number of scatterometer data points at daylight : {len(valid_times)}")
    return valid_times, valid_lats, valid_lons, valid_wind_speeds
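
A small self-contained sketch with synthetic values; only the 15:45 UTC observation survives the default 10-19 UTC window:

import numpy as np
from windscangeo.func import filter_nighttime

times = np.array(["2020-07-01T09:30", "2020-07-01T15:45"], dtype="datetime64[ns]")
lats = np.array([15.0, 16.0])
lons = np.array([-70.0, -71.0])
speeds = np.array([5.2, 7.8])

valid_times, valid_lats, valid_lons, valid_wind_speeds = filter_nighttime(times, lats, lons, speeds)
print(len(valid_times))                                  # 1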

form_arrays_buoy(buoy, date_choice)

Form arrays from buoy data for a specific date.

Parameters:

    buoy (Dataset): Buoy data containing time, latitude, longitude, and wind speed. Required.
    date_choice (str): Date for which to extract buoy data, in 'YYYY-MM. Required.

Returns:

    lat (ndarray): Array of buoy latitudes.
    lon (ndarray): Array of buoy longitudes.
    time (ndarray): Array of buoy observation times in nanoseconds since epoch.
    wind_speed (ndarray): Array of buoy wind speeds.
    buoy_name (ndarray): Array of buoy names.

Source code in windscangeo\func_inference.py, lines 163-201
def form_arrays_buoy(buoy, date_choice):

    """
    Form arrays from buoy data for a specific date.

    Args:
        buoy (xarray.Dataset): Buoy data containing time, latitude, longitude, and wind speed.
        date_choice (str): Date for which to extract buoy data, in 'YYYY-MM

    Returns:
        lat (np.ndarray): Array of buoy latitudes.
        lon (np.ndarray): Array of buoy longitudes.
        time (np.ndarray): Array of buoy observation times in nanoseconds since epoch.
        wind_speed (np.ndarray): Array of buoy wind speeds.
        buoy_name (np.ndarray): Array of buoy names.

    """
    try:
        start = np.datetime64(date_choice) - np.timedelta64(5, "m")
        end = np.datetime64(date_choice) + np.timedelta64(5, "m")

        wind_speed = buoy.sel(time=slice(start, end)).WS_401.values.flatten()
        time = (
            buoy.sel(time=slice(start, end))
            .time.values.astype("datetime64[ns]")
            .astype("int64")
        )
        lat = np.full(len(wind_speed), buoy.lat.values)
        lon = np.full(len(wind_speed), buoy.lon.values)
        buoy_name = np.full(len(wind_speed), buoy.platform_code, dtype=object)
    except:
        wind_speed = np.array([])
        time = np.array([])
        lat = np.array([])
        lon = np.array([])
        buoy_name = np.array([])

        print("date selection unavailable")
    return lat, lon, time, wind_speed, buoy_name
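
A minimal usage sketch with a hypothetical buoy file; the dataset is expected to expose time, lat, lon, platform_code and a WS_401 wind-speed variable, as used in the source above:

import xarray as xr
from windscangeo.func_inference import form_arrays_buoy

buoy = xr.open_dataset("buoy_42001.nc")                  # hypothetical buoy dataset
lat, lon, time, wind_speed, buoy_name = form_arrays_buoy(buoy, "2020-07-01 15:00:00")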

get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel='C01')

This function gets the nearest GOES-16 files from the time given. The function returns a list of urls to the files. The function uses the s3fs library to access the AWS GOES-16 data.

Parameters:

    time (datetime[ns]): The time of the scatterometer data. Required.
    goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES-16 data is stored. Default: 'noaa-goes16/ABI-L2-CMIPF'.
    goes_channel (list): The channel of interest. Default: 'C01'.

Returns:

    urls (list): A list of urls to the GOES-16 files.

Source code in windscangeo\func.py, lines 431-472
def get_goes_url(time, goes_aws_url_folder='noaa-goes16/ABI-L2-CMIPF', goes_channel="C01"):
    """
    This function gets the nearest GOES-16 files from the time given.
    The function returns a list of urls to the files.
    The function uses the s3fs library to access the AWS GOES-16 data.

    Args:
        time (numpy.datetime[ns]): The time of the scatterometer data.
        goes_aws_url_folder (str): The folder in the AWS S3 bucket where the GOES-16 data is stored.
        goes_channel (list): The channel of interest.

    Returns:
        urls (list): A list of urls to the GOES-16 files.


    """
    date_c = time.astype("datetime64[ns]")
    date = pd.to_datetime(date_c)
    date_str = date.strftime("%Y/%j/%H")
    min = int(date.strftime("%M"))
    min_range = [(min + i) % 60 for i in range(-6, 7)]
    min_range_str = [f"{x:02d}" for x in min_range]
    fs = s3fs.S3FileSystem(anon=True)
    # get the nearest goes file from time

    urls = []
    channel = goes_channel
    path = f"{goes_aws_url_folder}/{date_str}"
    files = fs.ls(path)
    filter_channel = [x for x in files if channel in x]
    if len(filter_channel) == 0:
        print(f"INFO :No file found for {channel} on day {date_str}, skipping file")
        return
    file = [x for x in filter_channel if x[73:75] in min_range_str]
    if len(file) == 0:
        print(
            f"INFO :No file found for {channel} on day {date_str} for minute {min}, skipping file"
        )
        return np.zeros(len(goes_channel))
    urls.append("s3://" + file[0])

    return urls
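
A minimal usage sketch (arbitrary example time; the import path is an assumption):

import numpy as np
from windscangeo.func import get_goes_url

urls = get_goes_url(np.datetime64("2020-07-01T15:00:00"), goes_channel="C01")
# urls is a list with a single 's3://noaa-goes16/...' entry for the nearest file,
# or a zero-filled array if no file within the +/- 6 minute window is found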

get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search, goes_image_size=128)

This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 .nc file.

Parameters:

    ds (Dataset): The xarray dataset containing the GOES data. Required.
    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    lat_grd (ndarray): The latitude grid of the scatterometer data. Required.
    lon_grd (ndarray): The longitude grid of the scatterometer data. Required.
    lat_search (float): The latitude to search for in the GOES data. Required.
    lon_search (float): The longitude to search for in the GOES data. Required.
    goes_image_size (int): The size of the output image. Default: 128.

Returns:

    padded_image (DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

Source code in windscangeo\func.py, lines 526-590
def get_image(ds, parallel_index, lat_grd, lon_grd, lat_search, lon_search,goes_image_size=128):

    """
    This function retrieves a trainable GOES image for a given latitude and longitude from a GOES16 `.nc` file.

    Args:
        ds (xarray.Dataset): The xarray dataset containing the GOES data.
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.
        goes_image_size (int): The size of the output image. Default is 128.

    Returns:
        padded_image (xarray.DataArray): A padded xarray DataArray containing the GOES image centered around the specified lat/lon.

    """
    index_row = np.where(
        lat_grd == lat_search,
    )
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    pixels_from_center = (goes_image_size-1) // 2
    mean_row = rows_goes.mean().astype(int)
    min_row = mean_row - pixels_from_center
    max_row = mean_row + pixels_from_center

    mean_col = columns_goes.mean().astype(int)
    min_col = mean_col - pixels_from_center
    max_col = mean_col + pixels_from_center

    if "CMI" in ds: # If using GOES-16 L2 processed data
        image = ds.CMI[min_row:max_row, min_col:max_col].values

    elif "Rad" in ds: #If using GOES-16 L1b data
        image = ds.Rad[min_row:max_row, min_col:max_col].values

    # debug
    # print(min_row,'= min_row', max_row,'= max_row', min_col, '= min_col', max_col, '= max_col')
    target_size = (goes_image_size, goes_image_size)

    padded_image = np.pad(
        image,
        (
            (
                (target_size[0] - image.shape[0]) // 2,
                (target_size[0] - image.shape[0] + 1) // 2,
            ),
            (
                (target_size[1] - image.shape[1]) // 2,
                (target_size[1] - image.shape[1] + 1) // 2,
            ),
        ),
        constant_values=0,
    )

    padded_image = xr.DataArray(padded_image, dims=("x", "y"))
    return padded_image

get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125)

Finds the corresponding GOES row and column indices for each scatterometer point using a BallTree for efficiency, and then filtering points to form a square bounding box.

Parameters:

    lat_grid (ndarray): 2D array of latitudes from the scatterometer data. Required.
    lon_grid (ndarray): 2D array of longitudes from the scatterometer data. Required.
    Goeslat (ndarray): 2D array of latitudes from the GOES data. Required.
    Goeslon (ndarray): 2D array of longitudes from the GOES data. Required.
    radius (float): Radius in degrees to define the bounding box around each scatterometer point. Default: 0.125.

Returns:

    indices_array (ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.

Source code in windscangeo\func.py, lines 108-184
def get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125):
    """
    Finds the corresponding GOES row and column indices for each scatterometer point
    using a BallTree for efficiency, and then filtering points to form a square bounding box.

    Args:
        lat_grid (numpy.ndarray): 2D array of latitudes from the scatterometer data.
        lon_grid (numpy.ndarray): 2D array of longitudes from the scatterometer data.
        Goeslat (numpy.ndarray): 2D array of latitudes from the GOES data.
        Goeslon (numpy.ndarray): 2D array of longitudes from the GOES data.
        radius (float): Radius in degrees to define the bounding box around each scatterometer point.
    Returns:
        indices_array (numpy.ndarray): 2D array of tuples, where each tuple contains the row and column indices of the corresponding GOES pixel for each scatterometer point.


    """

    print("INFO : Calculating indices")
    # Flatten GOES data
    Goeslat_flat = Goeslat.flatten()
    Goeslon_flat = Goeslon.flatten()
    goes_points = np.column_stack((Goeslat_flat, Goeslon_flat))

    # Build BallTree with haversine distance
    goes_points_rad = np.radians(goes_points)
    goes_tree = BallTree(goes_points_rad, metric="haversine")

    # Flatten scatter grids
    lat_flat = lat_grid.flatten()
    lon_flat = lon_grid.flatten()
    scatter_points = np.column_stack((lat_flat, lon_flat))
    scatter_points_rad = np.radians(scatter_points)

    # Radius for broad-phase query: diagonal of the bounding box
    # Square box ±radius: diagonal = radius * sqrt(2)
    diag_radius = radius * np.sqrt(2)
    diag_radius_rad = np.radians(diag_radius)

    indices_array = np.empty(lat_flat.shape, dtype=object)
    goes_shape = Goeslat.shape

    for i, (lat_val, lon_val) in enumerate(zip(lat_flat, lon_flat)):
        # Broad-phase: query all points within diagonal distance
        candidate_indices = goes_tree.query_radius(
            np.array([scatter_points_rad[i]]), r=diag_radius_rad
        )[0]

        if candidate_indices.size == 0:
            # No points found, store empty
            indices_array[i] = (np.array([], dtype=int), np.array([], dtype=int))
            continue

        # Post-filter candidates to keep only those in the bounding box
        lat_min = lat_val - radius
        lat_max = lat_val + radius
        lon_min = lon_val - radius
        lon_max = lon_val + radius

        cand_lats = Goeslat_flat[candidate_indices]
        cand_lons = Goeslon_flat[candidate_indices]

        mask = (
            (cand_lats >= lat_min)
            & (cand_lats <= lat_max)
            & (cand_lons >= lon_min)
            & (cand_lons <= lon_max)
        )

        final_indices = candidate_indices[mask]

        # Convert these flat indices back to row,col
        rows, cols = np.unravel_index(final_indices, goes_shape)
        indices_array[i] = (rows, cols)

    # Reshape indices_array to the original shape
    indices_array = indices_array.reshape(lat_grid.shape)
    return indices_array
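
A small synthetic sketch of the geometry (tiny made-up grids rather than real GOES geolocation), showing the shape of the result:

import numpy as np
from windscangeo.func import get_indices

# 0.25-degree "scatterometer" grid and a finer 0.02-degree "GOES-like" grid over the same box
lat_grid, lon_grid = np.meshgrid(np.arange(15.0, 16.0, 0.25), np.arange(-71.0, -70.0, 0.25), indexing="ij")
Goeslat, Goeslon = np.meshgrid(np.arange(14.5, 16.5, 0.02), np.arange(-71.5, -69.5, 0.02), indexing="ij")

indices = get_indices(lat_grid, lon_grid, Goeslat, Goeslon, radius=0.125)
rows, cols = indices[0, 0]          # GOES rows/cols that fall inside the box around the first scatterometer point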

get_mlp(in_features, hidden_units, out_features)

Returns an MLP head

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py, lines 173-185
def get_mlp(in_features, hidden_units, out_features):
    """
    Returns a MLP head

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    dims = [in_features] + hidden_units + [out_features]
    layers = []
    for dim1, dim2 in zip(dims[:-2], dims[1:-1]):
        layers.append(nn.Linear(dim1, dim2))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*layers)
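
A minimal sketch of the resulting head (import path assumed):

import torch
from windscangeo.Models import get_mlp

head = get_mlp(in_features=1024, hidden_units=[1024, 512], out_features=1)
out = head(torch.randn(8, 1024))
print(out.shape)                    # torch.Size([8, 1])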

goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search)

This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

Parameters:

    parallel_index (ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon. Required.
    lat_grd (ndarray): The latitude grid of the scatterometer data. Required.
    lon_grd (ndarray): The longitude grid of the scatterometer data. Required.
    lat_search (float): The latitude to search for in the GOES data. Required.
    lon_search (float): The longitude to search for in the GOES data. Required.

Returns:

    min_row (int): The minimum row index of the GOES image.
    max_row (int): The maximum row index of the GOES image.
    min_col (int): The minimum column index of the GOES image.
    max_col (int): The maximum column index of the GOES image.

Source code in windscangeo\func.py, lines 719-752
def goes_index(parallel_index, lat_grd, lon_grd, lat_search, lon_search):
    """
    This function retrieves the indices of the GOES image corresponding to a given latitude and longitude. This is an archived function. Current implementation decides on extent based on chosen image size.

    Args:
        parallel_index (numpy.ndarray): The precomputed indices for GOES pixels corresponding to scatterometer lat/lon.
        lat_grd (numpy.ndarray): The latitude grid of the scatterometer data.
        lon_grd (numpy.ndarray): The longitude grid of the scatterometer data.
        lat_search (float): The latitude to search for in the GOES data.
        lon_search (float): The longitude to search for in the GOES data.

    Returns:
        min_row (int): The minimum row index of the GOES image.
        max_row (int): The maximum row index of the GOES image.
        min_col (int): The minimum column index of the GOES image.
        max_col (int): The maximum column index of the GOES image.
    """

    index_row = np.where(lat_grd == lat_search)
    index_column = np.where(lon_grd == lon_search)

    rows_goes = parallel_index[index_row[0][0], index_column[0][0]][0]
    columns_goes = parallel_index[index_row[0][0], index_column[0][0]][1]

    if rows_goes.size == 0 or columns_goes.size == 0:
        return None

    min_row = rows_goes.min()
    max_row = rows_goes.max()

    min_col = columns_goes.min()
    max_col = columns_goes.max()

    return min_row, max_row, min_col, max_col

index_parallel(ds, ScatterDataset)

Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

Parameters:

    ScatterDataset: xarray Dataset containing scatterometer data. Required.
    scatter_name: Name for the output file. Required.
    output_path: Path to save the output file. Required.

Returns:

    parallel_indice_values: 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.

Source code in windscangeo\func.py, lines 248-302
def index_parallel(ds, ScatterDataset):
    """
    Finds the corresponding GOES row and column indices for the entire scatterometer dataset.

    Args:
        ScatterDataset: xarray Dataset containing scatterometer data.
        scatter_name: Name for the output file.
        output_path: Path to save the output file.

    Returns:
        parallel_indice_values: 2D array of tuples containing GOES row and column indices corresponding to scatterometer data.
    """

    create_folder("satellite_indices")
    ds_spatial_resolution = ds.spatial_resolution
    ds_spatial_resolution.replace(" ", "_")

    name_str = f"lat_{ScatterDataset.latitude.min().values}_{ScatterDataset.latitude.max().values}_lon_{ScatterDataset.longitude.min().values}_{ScatterDataset.longitude.max().values}_res_{ds_spatial_resolution}"
    name_str = name_str.replace(".", "_")
    if os.path.exists(
        f"./satellite_indices/{ds_spatial_resolution}_index.npy"
    ):
        parallel_index = np.load(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            allow_pickle=True,
        )

        return parallel_index

    else:
        print(
            "INFO : Satellite index file not found, creating new index file. This might take a while."
        )

        # Extract scatterometer latitudes and longitudes
        Latitudes_Scatter = ScatterDataset["latitude"].values
        Longitudes_Scatter = ScatterDataset["longitude"].values

        # Create a meshgrid of scatterometer coordinates
        lon_grid, lat_grid = np.meshgrid(Longitudes_Scatter, Latitudes_Scatter)

        # Extract GOES latitudes and longitudes
        Goeslat, Goeslon = calculate_degrees(ds)
        Goeslat[np.isnan(Goeslat)] = 999
        Goeslon[np.isnan(Goeslon)] = 999
        # Use the optimized get_indices function
        parallel_indice_values = get_indices(lat_grid, lon_grid, Goeslat, Goeslon)

        # Save the indices array
        np.save(
            f"./satellite_indices/{ds_spatial_resolution}_index.npy",
            parallel_indice_values,
        )

        return parallel_indice_values
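
A minimal sketch of how the index is typically built (hypothetical scatterometer file; any GOES file on the same fixed grid can supply the geolocation). The result is cached under ./satellite_indices/ and reused on later calls:

import numpy as np
import fsspec
import xarray as xr
from windscangeo.func import get_goes_url, index_parallel

fs = fsspec.filesystem("s3", anon=True)
url = get_goes_url(np.datetime64("2020-07-01T15:00:00"))[0]
template_scatter = xr.open_dataset("ascat_sample.nc").isel(time=0)       # hypothetical scatterometer grid

with fs.open(url, mode="rb") as f:
    goes_ds = xr.open_dataset(f, engine="h5netcdf", drop_variables=["DQF"])
    parallel_index = index_parallel(goes_ds, template_scatter)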

inference_model(model, inference_loader, device)

Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

Parameters:

    model (Module): The trained model to be used for inference. Required.
    inference_loader (DataLoader): DataLoader for the inference dataset. Required.
    device (device): Device to run the model on (CPU or GPU). Required.

Returns:

    inference_outputs (ndarray): Outputs from the model on the inference dataset.

Source code in windscangeo\impl.py, lines 261-290
def inference_model(model, inference_loader, device):
    """
    Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

    Args:
        model (torch.nn.Module): The trained model to be used for inference.
        inference_loader (torch.utils.data.DataLoader): DataLoader for the inference dataset.
        device (torch.device): Device to run the model on (CPU or GPU).

    Returns:
        inference_outputs (numpy.ndarray): Outputs from the model on the inference dataset.
    """

    with torch.no_grad():  # Disable gradient calculation for inference

        inference_outputs = []

        for images in inference_loader:
            images = images.to(device)

            outputs = model(images).squeeze(-1)

            # Append outputs to the list
            inference_outputs.append(outputs)

        inference_outputs = torch.cat(inference_outputs, dim=0)
        inference_outputs = inference_outputs.cpu()
        inference_outputs = inference_outputs.numpy()

    return inference_outputs

inference_run(images, model_parameters, model_path, normalization_factors)

Runs inference on the provided images using the specified model parameters and normalization factors.

Parameters:

    images (ndarray): Array of images to be used for inference. Required.
    model_parameters (dict): Dictionary containing model parameters such as batch size, image size, channels. Required.
    model_path (str): Path to the pre-trained model file. Required.
    normalization_factors (dict): Dictionary containing normalization factors such as mean and standard deviation. Required.

Returns:

    inference_output (ndarray): Array of inference outputs (wind speeds).

Source code in windscangeo\func_inference.py, lines 235-314
def inference_run(
    images, model_parameters, model_path, normalization_factors
):

    """
    Runs inference on the provided images using the specified model parameters and normalization factors.

    Args:
        images (np.ndarray): Array of images to be used for inference.
        model_parameters (dict): Dictionary containing model parameters such as batch size, image size, channels
        model_path (str): Path to the pre-trained model file.
        normalization_factors (dict): Dictionary containing normalization factors such as mean and standard deviation.

    Returns:
        inference_output (np.ndarray): Array of inference outputs (wind speeds).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = model_parameters["batch_size"]
    image_height = model_parameters["image_size"]
    image_width = model_parameters["image_size"]
    in_channels = model_parameters["image_channels"]
    dropout_rate = model_parameters["dropout_rate"]
    model_choice = model_parameters["model_choice"]

    mean = normalization_factors["mean"]
    std = normalization_factors["std"]
    if model_choice == 'CNN':

        features_cnn = model_parameters["features_cnn"]
        kernel_size = model_parameters["kernel_size"]
        activation_cnn = model_parameters["activation_cnn"]
        activation_final = model_parameters["activation_final"]
        stride = model_parameters["stride"]

        print("model choice is CNN")
        model = ConventionalCNN(
            image_height,
            image_width,
            features_cnn,
            kernel_size,
            in_channels,
            activation_cnn,
            activation_final,
            stride,
            dropout_rate
        ).to(device)

    if model_choice == 'ViT':
        print('model choice is ViT')
        # Load the model
        model = ViT(
            img_size = (128, 128),
            patch_size = (8,8),
            n_channels = 1,
            d_model = 1024,
            nhead = 4,
            dim_feedforward = 2048,
            blocks = 8,
            mlp_head_units = [1024, 512],
            n_classes = 1,
        ).to(device)

    if model_choice == 'ResNet':
        model = ResNet50(num_classes=1, channels=1).to(device)



    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    inference_dataset = conventional_dataset_inference(
        images,
        transform=Normalize(mean,std),
    )
    inference_loader = DataLoader(inference_dataset, batch_size, shuffle=False)

    inference_output = inference_model(model, inference_loader, device)

    return inference_output
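
A minimal sketch of calling the inference wrapper on synthetic images with the ResNet branch (the model path and normalization values are placeholders, not values shipped with the package):

import numpy as np
from windscangeo.func_inference import inference_run

model_parameters = {
    "batch_size": 64,
    "image_size": 128,
    "image_channels": 1,
    "dropout_rate": 0.0,
    "model_choice": "ResNet",
}
normalization_factors = {"mean": 0.18, "std": 0.12}      # placeholders: use the values saved at training time
images = np.random.rand(10, 1, 128, 128).astype(np.float32)

wind_speeds = inference_run(images, model_parameters, "model_weights.pth", normalization_factors)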

manage_saved_models(directory)

Manage saved model files in the specified directory by deleting older epoch files. Keeps only the latest epoch file and deletes all others. From @ Jing Sun

Parameters:

    directory (str): The directory where model files are saved. Required.

Source code in windscangeo\impl.py, lines 36-61
def manage_saved_models(directory):  # From @Jing
    """
    Manage saved model files in the specified directory by deleting older epoch files.
    Keeps only the latest epoch file and deletes all others. From @ Jing Sun

    Args:
        directory (str): The directory where model files are saved.
    """

    pattern = re.compile(r"epoch_(\d+)\.pth")
    epoch_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            match = pattern.match(file)
            if match:
                epoch_num = int(match.group(1))
                file_path = os.path.join(root, file)
                epoch_files.append((file_path, epoch_num))

    # Check if there are more than 5 files
    if len(epoch_files) > 1:
        epoch_files.sort(key=lambda x: x[1])
        files_to_delete = len(epoch_files) - 1

        for i in range(files_to_delete):
            os.remove(epoch_files[i][0])

package_data(images, numerical_data, filter=True, solar_conversion=False, verbose=True)

This function packages the images and numerical data into a format that can be used for training a machine learning model. It filters out invalid images (i.e., empty images from the GOES data) and fills in any NaN values. If solar_conversion is set to True, it also converts the observation times, latitudes and longitudes to solar angles (sza, saa). The images and numerical data are returned as numpy arrays.

Parameters:

    images (ndarray): The GOES images corresponding to the observation data. Required.
    numerical_data (dict): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds". Required.
    filter (bool): If True, the function will filter out invalid images and fill in NaN values. Default: True.
    solar_conversion (bool): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). Default: False. (Not used in the current implementation, but kept in case of future use.)
    verbose (bool): If True, the function will print progress information. Default: True.

Returns:

    images (ndarray): The GOES images corresponding to the observation data.
    numerical_data (ndarray): The numerical data corresponding to the observation data (sza, saa, main_parameter if solar_conversion is set to True, or lat, lon, time, wind_speeds if solar_conversion is set to False).

Source code in windscangeo\func.py, lines 876-935
def package_data(
    images,
    numerical_data,
    filter=True,
    solar_conversion=False,
    verbose=True
):
    """
    This function packages the images and numerical data into a format that can be used for training a machine learning model.
    The function will filter out invalid images and fill in any NaN values. (Invalid images = empty images from GOES data)
    The function will also convert the observation times, latitudes and longitudes to solar angles (sza, saa) if solar_conversion is set to True.
    The function will return the images and numerical data in a numpy array format.

    Args:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (dict): A dictionary containing the numerical data corresponding to the observation data. The keys should include "observation_lats", "observation_lons", "observation_times" and optionally "wind_speeds".
        filter (bool): If True, the function will filter out invalid images and fill in Nan values. Default is True.
        solar_conversion (bool): If True, the function will convert the observation times, latitudes and longitudes to solar angles (sza, saa). Default is False. (Not used in current implementation, but kept in case of future use)
        verbose (bool): If True, the function will print progress information. Default is True.

    Returns:
        images (numpy.ndarray): The GOES images corresponding to the observation data.
        numerical_data (numpy.ndarray): The numerical data corresponding to the observation data. (sza, saa, main_parameter if solar_conversion is set to True or lat, lon, time, wind_speeds if solar_conversion is set to False)

    """
    if filter:
        (images, numerical_data) = filter_invalid(images, numerical_data)
        images = fill_nans(images)

    if solar_conversion:
        observation_lats = numerical_data["observation_lats"]
        observation_lons = numerical_data["observation_lons"]
        observation_times = numerical_data["observation_times"]

        sza, saa = vectorized_solar_angles(
            observation_lats, observation_lons, observation_times
        )

        sza_rad = np.deg2rad(sza)
        sza_sin = np.sin(sza_rad)
        sza_cos = np.cos(sza_rad)

        saa_rad = np.deg2rad(saa)
        saa_sin = np.sin(saa_rad)
        saa_cos = np.cos(saa_rad)

        # Add the solar angles to the numerical data dictionary

        numerical_data["sza_sin"] = sza_sin
        numerical_data["sza_cos"] = sza_cos
        numerical_data["saa_sin"] = saa_sin
        numerical_data["saa_cos"] = saa_cos

        print("Data Preparation : converted to solar angles (sza, saa)")
        print("Data Preparation : returning images, numerical_data")
        return images, numerical_data

    else:

        return images, numerical_data
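
A minimal sketch, continuing from the arrays produced by the extraction and daylight-filtering steps above (an assumed pipeline order); the lists are converted to numpy arrays so that filter_invalid can mask them:

import numpy as np
from windscangeo.func import package_data

numerical_data = {
    "observation_times": np.asarray(valid_times),
    "observation_lats": np.asarray(valid_lats),
    "observation_lons": np.asarray(valid_lons),
    "wind_speeds": np.asarray(valid_wind_speeds),
}
images, numerical_data = package_data(images, numerical_data, filter=True, solar_conversion=False)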

patchify(batch, patch_size)

Patchify the batch of images

Shape:

    batch: (b, h, w, c)
    output: (b, nh, nw, ph, pw, c)

taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch

Source code in windscangeo\Models.py, lines 187-204
def patchify(batch, patch_size):
    """
    Patchify the batch of images

    Shape:
        batch: (b, h, w, c)
        output: (b, nh, nw, ph, pw, c)

    taken from https://www.kaggle.com/code/umongsain/vision-transformer-from-scratch-pytorch
    """
    b, c, h, w = batch.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw

    batch_patches = torch.reshape(batch, (b, c, nh, ph, nw, pw))
    batch_patches = torch.permute(batch_patches, (0, 1, 2, 4, 3, 5))

    return batch_patches
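
A minimal sketch of the tensor layout actually produced by the code above (the input is unpacked channels-first, (b, c, h, w)):

import torch
from windscangeo.Models import patchify

batch = torch.randn(8, 1, 128, 128)              # (b, c, h, w), as unpacked in the source above
patches = patchify(batch, patch_size=(8, 8))
print(patches.shape)                             # torch.Size([8, 1, 16, 16, 8, 8])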

plot_cloud_cover(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted. Can be adapted to work with larger images.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. Required.
    lon_inference (ndarray): Longitude values for the inference grid. Required.
    images (ndarray): GOES image data to be used for cloud cover calculation. Required.
    path_folder (str): Path to save the plot. Required.
    buoy_name (list or ndarray): Names of the buoys. Required.
    buoy_lat (list or ndarray): Latitude values of the buoys. Required.
    buoy_lon (list or ndarray): Longitude values of the buoys. Required.
    time_choice (datetime): Time of the inference. Required.

Source code in windscangeo\func_ml.py, lines 458-482
def plot_cloud_cover(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud cover mask on a map with buoy locations. This works well with 30x30 images but larger images can diluted. Can be adapted to work with larger images.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be used for cloud cover calculation.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = np.mean(images, axis=(2,3))
    threshold = 0.11
    cloud_mask = np.where(mean_images > threshold, 1, 0)
    cloud_mask = cloud_mask.reshape(160,340)

    plot_cloud_mask(lat_inference,lon_inference,cloud_mask,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice)

    percentage_cloud = np.sum(cloud_mask)/cloud_mask.size
    print('Cloud coverage : ',percentage_cloud*100,'%')

    return cloud_mask, percentage_cloud

plot_cloud_mask(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the cloud mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted. Can be adapted to work with larger images.

Parameters:

    lat_inference (ndarray): Latitude values for the inference grid. Required.
    lon_inference (ndarray): Longitude values for the inference grid. Required.
    wind_speeds_inference (ndarray): Cloud mask data to be plotted. Required.
    path_folder (str): Path to save the plot. Required.
    buoy_name (list or ndarray): Names of the buoys. Required.
    buoy_lat (list or ndarray): Latitude values of the buoys. Required.
    buoy_lon (list or ndarray): Longitude values of the buoys. Required.
    time_choice (datetime): Time of the inference. Required.

Source code in windscangeo\func_ml.py
def plot_cloud_mask(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the cloud mask on a map with buoy locations. This works well with 30x30 images, but larger images can appear diluted; the function can be adapted to handle larger image sizes.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Cloud mask data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """

    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, wind_speeds_inference,
        shading='auto', cmap='Blues',
        vmin=0, vmax=1
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='0 = Clear, 1 = Cloudy')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)

    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))

    # Plot each buoy in a single color
    for i, buoy_id in enumerate(unique_buoys):
        # Pick a distinct color from tab20
        color = color_map(i)

        # Identify the indices belonging to this buoy
        mask = np.array(buoy_name_flat) == buoy_id

        # Scatter just those points
        ax.scatter(
            np.array(buoy_lon)[mask],
            np.array(buoy_lat)[mask],
            s=100,
            color=color,            # Set the fill color
            edgecolor='black',
            linewidth=1,
            zorder=12,
            label=buoy_id           # Use buoy_id as the legend label
        )

    # Create the legend
    leg = ax.legend(
        title="Buoy Stations",
        loc="upper right",
        bbox_to_anchor=(1.0, 1.0),
        bbox_transform=ax.transAxes
    )

    # Ensure legend is above all other layers
    leg.set_zorder(999)

    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'Cloud Mask at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())


    plt.savefig(f'{path_folder}/plot_cloud_mask.png')

plot_goes_image(lat_inference, lon_inference, images, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the GOES image data on a map with buoy locations. Made for 128x128 images whose midpoint is at (64, 64); with other image sizes the plotting will probably not work as expected.

Parameters:

Name Type Description Default
lat_inference ndarray

Latitude values for the inference grid.

required
lon_inference ndarray

Longitude values for the inference grid.

required
images ndarray

GOES image data to be plotted.

required
path_folder str

Path to save the plot.

required
buoy_name list or ndarray

Names of the buoys.

required
buoy_lat list or ndarray

Latitude values of the buoys.

required
buoy_lon list or ndarray

Longitude values of the buoys.

required
time_choice datetime

Time of the inference.

required
Source code in windscangeo\func_ml.py
def plot_goes_image(lat_inference,lon_inference,images,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the GOES image data on a map with buoy locations. Made for 128x128 images whose midpoint is at (64, 64); with other image sizes the plotting will probably not work as expected.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        images (np.ndarray): GOES image data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    mean_images = images[:,:,64,64]
    mean_images = mean_images.ravel()
    mean_images = mean_images.reshape(160,340)


    fig = plt.figure(figsize=(20, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())


    # Plot wind speed data (continuous colormap)
    pcm = ax.pcolormesh(
        lon_inference, lat_inference, mean_images,
        shading='auto',
        vmin=0,
        vmax=1,
    )

    # Add colorbar
    cbar = fig.colorbar(pcm, label='Brightness Temperature (K) - C01')

    # Add coastlines and land
    ax.add_feature(cfeature.LAND, color='white', alpha=1, zorder=10)  
    ax.coastlines(zorder=11)


    # Flatten buoy_name if it's a list of arrays
    buoy_name_flat = np.concatenate(buoy_name).tolist()
    unique_buoys = list(set(buoy_name_flat))

    # Generate enough distinct colors for all buoys from the "tab20" palette
    # (tab20 provides 20 colors; if there are more than 20 unique buoys, colors will repeat)
    color_map = plt.cm.get_cmap("tab20", len(unique_buoys))


    # Add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
    gl.right_labels = False
    gl.top_labels = False

    ax.set_title(f'GOES image at time {time_choice}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_ylim(lat_inference.min(), lat_inference.max())
    ax.set_xlim(lon_inference.min(), lon_inference.max())

    plt.savefig(f'{path_folder}/plot_goes_image.png')

    mean_images = np.array(mean_images)

    return mean_images
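
A scaled-down sketch of the centre-pixel extraction that plot_goes_image performs before plotting (the real function hard-codes a 160x340 grid of 128x128 patches; the smaller grid below is for illustration only):

import numpy as np

# One 128x128 GOES patch per grid point, here a 4x5 grid instead of the hard-coded 160x340
images = np.random.rand(4 * 5, 1, 128, 128)

center_values = images[:, :, 64, 64].ravel()   # value at the patch midpoint (64, 64)
grid = center_values.reshape(4, 5)             # plot_goes_image uses reshape(160, 340)

print(grid.shape)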

plot_save_loss(best_val_outputs, best_val_labels, train_losses, val_losses, path_folder, saving=False)

Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

Parameters:

Name Type Description Default
best_val_outputs list or ndarray

Model outputs for the validation dataset.

required
best_val_labels list or ndarray

True labels for the validation dataset.

required
train_losses list

List of training losses per epoch.

required
val_losses list

List of validation losses per epoch.

required
path_folder str

Path to save the plot and optionally the outputs and labels.

required
saving bool

If True, save the best validation outputs and labels. Default is False.

False
Source code in windscangeo\func_ml.py
def plot_save_loss(
    best_val_outputs,
    best_val_labels,
    train_losses,
    val_losses,
    path_folder,
    saving=False,
):
    """
    Plot and save the training and validation losses, and optionally save the best validation outputs and labels.

    Args:
        best_val_outputs (list or np.ndarray): Model outputs for the validation dataset.
        best_val_labels (list or np.ndarray): True labels for the validation dataset.
        train_losses (list): List of training losses per epoch.
        val_losses (list): List of validation losses per epoch.
        path_folder (str): Path to save the plot and optionally the outputs and labels.
        saving (bool, optional): If True, save the best validation outputs and labels. Default is False.
    """
    # After training, save only the best validation outputs and labels
    if saving:
        np.save(
            os.path.join(path_folder, "best_validation_outputs.npy"), best_val_outputs
        )
        np.save(
            os.path.join(path_folder, "best_validation_labels.npy"), best_val_labels
        )

    num_epochs = len(train_losses)
    # Plotting the training and validation losses
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, num_epochs + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    text_str = f"num_epochs = {num_epochs}, train loss = {train_losses[-1]:.2f}, validation loss = {val_losses[-1]:.2f}"
    plt.text(
        0.05,
        0.05,
        text_str,
        ha="left",
        va="bottom",
        transform=plt.gca().transAxes,  # Ensures the coordinates are relative to the axes (0 to 1 range)
    )
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(path_folder, "loss_plot.png"))
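
A minimal usage sketch with synthetic loss curves; the import path is inferred from the source location above and the data here are placeholders:

import numpy as np
from windscangeo.func_ml import plot_save_loss  # import path inferred from the source location above

train_losses = list(np.linspace(2.0, 0.6, 50))
val_losses = list(np.linspace(2.2, 0.8, 50))
best_val_outputs = np.random.rand(100)   # placeholder validation predictions
best_val_labels = np.random.rand(100)    # placeholder validation targets

# Writes loss_plot.png (and the two .npy files, since saving=True) into the current folder
plot_save_loss(best_val_outputs, best_val_labels, train_losses, val_losses,
               path_folder=".", saving=True)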

plot_wind_speeds(lat_inference, lon_inference, wind_speeds_inference, path_folder, buoy_name, buoy_lat, buoy_lon, time_choice)

Plot the wind speeds on a map with buoy locations. Nighttime regions (solar zenith angle > 90°) are masked out, and coastlines, gridlines, and buoy locations are added.

Parameters:

Name Type Description Default
lat_inference ndarray

Latitude values for the inference grid.

required
lon_inference ndarray

Longitude values for the inference grid.

required
wind_speeds_inference ndarray

Wind speed data to be plotted.

required
path_folder str

Path to save the plot.

required
buoy_name list or ndarray

Names of the buoys.

required
buoy_lat list or ndarray

Latitude values of the buoys.

required
buoy_lon list or ndarray

Longitude values of the buoys.

required
time_choice datetime

Time of the inference.

required
Source code in windscangeo\func_ml.py
def plot_wind_speeds(lat_inference,lon_inference,wind_speeds_inference,path_folder,buoy_name,buoy_lat,buoy_lon,time_choice):
    """
    Plot the wind speeds on a map with buoy locations. Nighttime regions (solar zenith angle > 90°) are masked out, and coastlines, gridlines, and buoy locations are added.

    Args:
        lat_inference (np.ndarray): Latitude values for the inference grid.
        lon_inference (np.ndarray): Longitude values for the inference grid.
        wind_speeds_inference (np.ndarray): Wind speed data to be plotted.
        path_folder (str): Path to save the plot.
        buoy_name (list or np.ndarray): Names of the buoys.
        buoy_lat (list or np.ndarray): Latitude values of the buoys.
        buoy_lon (list or np.ndarray): Longitude values of the buoys.
        time_choice (datetime): Time of the inference.
    """
    time_str = time_choice.strftime('%Y-%m-%d %H:%M:%S')
    date, time = time_str.split(' ')
    lat = lat_inference
    lon = lon_inference
    wind_speeds = wind_speeds_inference


    buoy_names = buoy_name
    buoy_lat = buoy_lat
    buoy_lon = buoy_lon


    # nighttime mask

    lat_flat = lat.flatten()
    lon_flat = lon.flatten()
    time_flat = np.full(len(lat_flat), pd.Timestamp(f'{date} {time}'), dtype='datetime64[ns]')


    sza, saa = vectorized_solar_angles(lat_flat, lon_flat, time_flat)
    saa = np.reshape(saa,lat.shape)
    sza = np.reshape(sza,lat.shape)

    night_time_mask = np.where(sza > 90, 1, np.nan)
    cmap = ListedColormap(['white'])

    ###############

    min_lon, max_lon, min_lat, max_lat = -70, 0, -12, 20

    # add coastlines and gridlines
    fig = plt.figure(figsize=(22, 10), dpi=100)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    levels = np.arange(0, 13, 1)
    line_colour = 'black'
    line_colours = ['black' for i in levels]
    ax.title.set_text(f'Wind Speed prediction from C01 GOES image (m/s) {date} {time} ')
    ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15,zorder = 5)
    ax.contourf(lon, lat, night_time_mask, transform=ccrs.PlateCarree(),cmap=cmap,zorder = 10)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.LAND, edgecolor='black',zorder= 20)
    ax.set_xticks(np.arange(-70, 1, 10), crs=ccrs.PlateCarree())
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0f}°E'))
    ax.set_yticks(np.arange(-15, 21, 5), crs=ccrs.PlateCarree())
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0f}°N'))
    ax.set_extent((-70, 0, -12, 20))
    ax.hlines(0, -70, 0, color='red', linewidth=1.5, linestyle='--', zorder= 10)
    fig.colorbar(ax.pcolormesh(lon, lat, wind_speeds, transform=ccrs.PlateCarree(), cmap='jet',alpha=0.6,vmin=0,vmax=15), ax=ax, orientation='vertical', aspect=50, label='Wind Speed (m/s)')
    ax.gridlines(color='gray', linestyle='--', alpha=0.5,zorder= 999)

    for buoy in range(len(buoy_names)):
        lon_b, lat_b = buoy_lon[buoy], buoy_lat[buoy]

        # Skip if out of bounds
        if not (min_lon <= lon_b <= max_lon and min_lat <= lat_b <= max_lat):
            continue

        ax.plot(lon_b, lat_b, 'o', color='red', markersize=5, transform=ccrs.PlateCarree(), label=buoy_names[buoy],zorder = 999)

        ax.text(
            lon_b + 0.2, lat_b + 0.2, buoy_names[buoy],
            fontsize=9, color='white',
            transform=ccrs.PlateCarree(),
            bbox=dict(facecolor='black', edgecolor='none', boxstyle='square,pad=0.2'),zorder = 999
        )
    plt.savefig(f'{path_folder}/plot_wind_speeds.png')

rmse_per_range(model_output, target, path_folder)

Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

Parameters:

Name Type Description Default
model_output list or ndarray

Model outputs for the validation dataset.

required
target list or ndarray

True labels for the validation dataset.

required
path_folder str

Path to save the CSV file.

required

Returns: pd.DataFrame: DataFrame containing the RMSE and count for each range.

Source code in windscangeo\func_ml.py
def rmse_per_range(model_output, target,path_folder):
    """
    Calculate the RMSE for different ranges of wind speeds and save the results to a CSV file.

    Args:
        model_output (list or np.ndarray): Model outputs for the validation dataset.
        target (list or np.ndarray): True labels for the validation dataset.
        path_folder (str): Path to save the CSV file.
    Returns:
        pd.DataFrame: DataFrame containing the RMSE and count for each range.
    """

    max_target = np.max(target)
    bins = np.arange(0, max_target, 1)
    rmse = np.zeros(len(bins))
    count = np.zeros(len(bins))
    results = []

    for i in range(len(bins)-1):
        idx = np.where((target >= bins[i]) & (target <= bins[i+1]))
        rmse[i] = np.sqrt(np.mean((model_output[idx] - target[idx])**2))
        count[i] = len(idx[0])
        print(f"EVAL : Range {bins[i]} m/s - {bins[i+1]} m/s: RMSE = {rmse[i]}, count = {int(count[i])}")
        results.append({'bin_start': bins[i], 'bin_end': bins[i+1], 'rmse': rmse[i], 'count': count[i]})

    df = pd.DataFrame(results)
    df.to_csv(f'{path_folder}/rmse_per_range.csv')
    return df
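
A minimal usage sketch with synthetic predictions (the import path is inferred from the source location above):

import numpy as np
from windscangeo.func_ml import rmse_per_range  # import path inferred from the source location above

target = np.random.uniform(0, 12, size=1000)                # "true" wind speeds in m/s
model_output = target + np.random.normal(0, 1, size=1000)   # noisy predictions

df = rmse_per_range(model_output, target, path_folder=".")
print(df.head())  # one row per 1 m/s bin with its RMSE and sample count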

save_overpass_time(time_list, name_scatter)

This function prints the overpass time of the scatterometer.

Parameters:

Name Type Description Default
time_list ndarray

The measurement time values of the scatterometer data.

required
name_scatter str

The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc).

required

Returns:

Type Description

None

Source code in windscangeo\func.py
def save_overpass_time(time_list,name_scatter):
    """
    This function prints the overpass time of the scatterometer.

    Args:
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        name_scatter (str): The name of the scatterometer data source (e.g. ASCAT, HYSCAT etc).

    Returns:
        None 

    """
    formated_time = time_list.astype('datetime64[ns]')
    hour_minute = formated_time.astype('datetime64[m]')
    unique_hour_minute = np.unique(hour_minute)

    filtered = [unique_hour_minute[0]]

    delta = np.timedelta64(1, 'h')

    for time in unique_hour_minute[1:]:
        if time - filtered[-1] >= delta:
            filtered.append(time)

    time_only = []
    for time in filtered:
        time = str(time).split('T')[1]
        time_only.append(time)
    print(f"ORBIT : {name_scatter} overpass time : {time_only}")

savedataseperated(ScatterData, main_parameter, verbose=True)

This function extracts the valid lon / lat / measurement time and the main parameter from every pixel of the scatterometer data and saves it to a numpy file.

Parameters:

Name Type Description Default
ScatterData Dataset

The ASCAT dataset containing the scatterometer data.

required
main_parameter DataArray

The main parameter to be saved. This can be a classification / wind speed / wind direction etc.

required

Returns:

lat_list (numpy.ndarray): The latitude values of the scatterometer data.
lon_list (numpy.ndarray): The longitude values of the scatterometer data.
time_list (numpy.ndarray): The measurement time values of the scatterometer data.
main_parameter_list (numpy.ndarray): The main parameter values of the scatterometer data.

This function saves the data locally to a folder called data_processed_scat.

Source code in windscangeo\func.py
def savedataseperated(ScatterData, main_parameter,verbose=True):
    """
    This function extracts the valid lon / lat / measurement time and the main parameter from every pixel
    of the scatterometer data and saves it to a numpy file.

    Args:
        ScatterData (xarray.Dataset): The ASCAT dataset containing the scatterometer data.
        main_parameter (xarray.DataArray): The main parameter to be saved. This can be a classification / wind speed / wind direction etc.

    Returns:

        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        main_parameter_list (numpy.ndarray): The main parameter values of the scatterometer data.

    This function saves the data locally to a folder called data_processed_scat.
    """
    lat_full, lon_full, time_full = ScatterData.indexes.values()
    measurement_time_full = ScatterData.measurement_time

    lat_full = np.array(lat_full)
    lon_full = np.array(lon_full)
    measurement_time_full = np.array(measurement_time_full)
    main_parameter = np.array(main_parameter)

    index = np.argwhere(~np.isnan(main_parameter))

    index_list = []
    lat_list = []
    lon_list = []
    time_list = []
    wind_speed_list = []

    name_scatter = ScatterData.source

    for t, i, j in index:

        # print(t,'= time', i,'=row', j, '=column')
        index_list.append((t, i, j))

        # print(measurement_time_full[t, i, j].astype('datetime64[ns]'))
        time_list.append(measurement_time_full[t, i, j])

        # print(lat_full[i])
        lat_list.append(lat_full[i])

        # print(lon_full[j])
        lon_list.append(lon_full[j])

        # print(AllWindSpeeds[t, i, j])
        wind_speed_list.append(main_parameter[t, i, j])

    lat_list = np.array(lat_list)
    lon_list = np.array(lon_list)
    time_list = np.array(time_list)
    wind_speed_list = np.array(wind_speed_list)

    lat_list, lon_list, time_list, wind_speed_list = sort_by_time(
        lat_list, lon_list, time_list, wind_speed_list
    )
    if verbose:
        save_overpass_time(time_list,name_scatter)    

    return lat_list, lon_list, time_list, wind_speed_list

snap_to_nearest(values, reference_array, cutoff=1.0)

Snap an array of values to the nearest values in a reference array. If the difference is greater than the cutoff, the original value is returned.

Parameters:

Name Type Description Default
values ndarray

Array of values to snap.

required
reference_array ndarray

Array of reference values.

required
cutoff float

Maximum allowable difference for snapping.

1.0

Returns:

Type Description

np.ndarray: Snapped values.

Source code in windscangeo\func_inference.py
def snap_to_nearest(values, reference_array, cutoff=1.0):
    """
    Snap an array of values to the nearest values in a reference array.
    If the difference is greater than the cutoff, the original value is returned.

    Args:
        values (np.ndarray): Array of values to snap.
        reference_array (np.ndarray): Array of reference values.
        cutoff (float): Maximum allowable difference for snapping.

    Returns:
        np.ndarray: Snapped values.
    """
    # Convert inputs to NumPy arrays for compatibility
    values = np.asarray(values)
    reference_array = np.asarray(reference_array)

    # Find the nearest reference value for each input value
    # Reshape reference_array to allow broadcasting
    reference_array = reference_array.reshape(1, -1)
    differences = np.abs(values.reshape(-1, 1) - reference_array)
    nearest_indices = np.argmin(differences, axis=1)
    nearest_values = reference_array.ravel()[nearest_indices]

    # Apply the cutoff condition
    snap_mask = np.abs(values - nearest_values) <= cutoff
    snapped_values = np.where(snap_mask, nearest_values, values)

    return snapped_values
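
For example (the import path is inferred from the source location above):

import numpy as np
from windscangeo.func_inference import snap_to_nearest  # import path inferred from the source location above

values = np.array([1.2, 5.7, 10.3])
reference = np.array([0.0, 2.0, 4.0, 6.0])

print(snap_to_nearest(values, reference, cutoff=1.0))
# [ 2.   6.  10.3]  -- 10.3 is more than 1.0 away from every reference value, so it is left unchanged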

sort_by_time(lat_list, lon_list, time_list, wind_speed_list)

This function sorts the output of savedataseperated() by time. This allows for more efficient data processing and enables file caching for times that are covered by the same GOES file.

Parameters:

Name Type Description Default
lat_list ndarray

The latitude values of the scatterometer data.

required
lon_list ndarray

The longitude values of the scatterometer data.

required
time_list ndarray

The measurement time values of the scatterometer data.

required
wind_speed_list ndarray

The wind speed values of the scatterometer data.

required

Returns:

Name Type Description
lat_list_sorted ndarray

The sorted latitude values of the scatterometer data.

lon_list_sorted ndarray

The sorted longitude values of the scatterometer data.

time_list_sorted ndarray

The sorted measurement time values of the scatterometer data.

wind_speed_list_sorted ndarray

The sorted wind speed values of the scatterometer data.

Source code in windscangeo\func.py
def sort_by_time(lat_list, lon_list, time_list, wind_speed_list):
    """
    This function sorts the output of savedataseperated() by time.
    This allows for more efficient data processing and enables file caching for times that are covered by the same GOES file.

    Args:
        lat_list (numpy.ndarray): The latitude values of the scatterometer data.
        lon_list (numpy.ndarray): The longitude values of the scatterometer data.
        time_list (numpy.ndarray): The measurement time values of the scatterometer data.
        wind_speed_list (numpy.ndarray): The wind speed values of the scatterometer data.

    Returns:
        lat_list_sorted (numpy.ndarray): The sorted latitude values of the scatterometer data.
        lon_list_sorted (numpy.ndarray): The sorted longitude values of the scatterometer data.
        time_list_sorted (numpy.ndarray): The sorted measurement time values of the scatterometer data.
        wind_speed_list_sorted (numpy.ndarray): The sorted wind speed values of the scatterometer data.

    """
    # Get the indices that would sort the measurement_time array
    sorted_indices = np.argsort(time_list)

    # Reorder the arrays using the sorted indices
    time_list_sorted = time_list[sorted_indices]
    lat_list_sorted = lat_list[sorted_indices]
    lon_list_sorted = lon_list[sorted_indices]
    speed_list_sorted = wind_speed_list[sorted_indices]

    return lat_list_sorted, lon_list_sorted, time_list_sorted, speed_list_sorted
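
A small usage example with synthetic, unsorted samples (the import path is inferred from the source location above):

import numpy as np
from windscangeo.func import sort_by_time  # import path inferred from the source location above

time_list = np.array(['2022-06-01T12:00', '2022-06-01T09:30', '2022-06-01T21:15'],
                     dtype='datetime64[ns]')
lat_list = np.array([10.0, 11.0, 12.0])
lon_list = np.array([-40.0, -41.0, -42.0])
wind_speed_list = np.array([5.2, 7.8, 6.1])

lat_s, lon_s, time_s, speed_s = sort_by_time(lat_list, lon_list, time_list, wind_speed_list)
print(time_s)   # chronological order: 09:30, 12:00, 21:15
print(speed_s)  # [7.8 5.2 6.1] -- reordered together with the times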

test_model(model, test_loader, criterion, device)

Evaluate the model on the test dataset and return the outputs, targets, and average loss.

Parameters:

Name Type Description Default
model Module

The trained model to be evaluated.

required
test_loader DataLoader

DataLoader for the test dataset.

required
criterion Module

Loss function to be used for evaluation.

required
device device

Device to run the model on (CPU or GPU).

required

Returns:

Name Type Description
test_outputs ndarray

Outputs from the model on the test dataset.

test_targets ndarray

Targets corresponding to the test outputs.

avg_test_loss float

Average loss on the test dataset.

Source code in windscangeo\impl.py
def test_model(model, test_loader, criterion, device):
    """
    Evaluate the model on the test dataset and return the outputs, targets, and average loss.

    Args:
        model (torch.nn.Module): The trained model to be evaluated.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        criterion (torch.nn.Module): Loss function to be used for evaluation.
        device (torch.device): Device to run the model on (CPU or GPU).

    Returns:
        test_outputs (numpy.ndarray): Outputs from the model on the test dataset.
        test_targets (numpy.ndarray): Targets corresponding to the test outputs.
        avg_test_loss (float): Average loss on the test dataset.
    """

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for inference

        test_outputs = []
        test_targets = []
        test_loss = 0.0

        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images).squeeze(-1)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            # Append outputs to the list
            test_outputs.append(outputs)
            test_targets.append(targets)

        avg_test_loss = test_loss / len(test_loader)
        print(f"EVAL : Test Loss: {avg_test_loss}")

        test_outputs = torch.cat(test_outputs, dim=0)
        test_outputs = test_outputs.cpu()
        test_outputs = test_outputs.numpy()

        test_targets = torch.cat(test_targets, dim=0)
        test_targets = test_targets.cpu()
        test_targets = test_targets.numpy()

    return test_outputs, test_targets, avg_test_loss
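
A minimal sketch with a stand-in model and synthetic data, assuming the function can be imported from windscangeo.impl (as the source path above suggests):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from windscangeo.impl import test_model  # import path inferred from the source location above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tiny stand-in regression model: one 32x32 single-channel image -> one wind-speed value
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1)).to(device)

images = torch.randn(64, 1, 32, 32)
targets = torch.rand(64) * 15.0
test_loader = DataLoader(TensorDataset(images, targets), batch_size=16)

outputs, labels, avg_loss = test_model(model, test_loader, nn.MSELoss(), device)
print(outputs.shape, labels.shape, avg_loss)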

train_model(model, train_loader, val_loader, num_epochs, lr, weight_decay, criterion, device, optimizer_choice, patience_epochs, patience_loss, path_folder)

Train the model with the given parameters dictionary and save the best validation outputs, labels, and model.

Parameters:

Name Type Description Default
model Module

The model to be trained.

required
train_loader DataLoader

DataLoader for the training dataset.

required
val_loader DataLoader

DataLoader for the validation dataset.

required
num_epochs int

Number of epochs to train the model.

required
lr float

Learning rate for the optimizer.

required
weight_decay float

Weight decay for the optimizer.

required
criterion Module

Loss function to be used.

required
device device

Device to run the model on (CPU or GPU).

required
optimizer_choice str

Choice of optimizer ('Adam', 'SGD', 'RMSprop').

required
patience_epochs int

Number of epochs to wait before stopping if no improvement in validation loss.

required
patience_loss float

Minimum change in validation loss to consider as an improvement.

required
path_folder str

Path to save the model checkpoints.

required

Returns:

Name Type Description
best_val_outputs ndarray

Best validation outputs from the model.

best_val_labels ndarray

Best validation labels corresponding to the outputs.

best_model Module

The best model based on validation loss.

train_losses list

List of training losses for each epoch.

val_losses list

List of validation losses for each epoch.

Source code in windscangeo\impl.py
def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs,
    lr,
    weight_decay,
    criterion,
    device,
    optimizer_choice,
    patience_epochs,
    patience_loss,
    path_folder,
):

    """
    Train the model with the given parameters dictionary and save the best validation outputs, labels, and model.

    Args:
        model (torch.nn.Module): The model to be trained.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        num_epochs (int): Number of epochs to train the model.
        lr (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        criterion (torch.nn.Module): Loss function to be used.
        device (torch.device): Device to run the model on (CPU or GPU).
        optimizer_choice (str): Choice of optimizer ('Adam', 'SGD', 'RMSprop').
        patience_epochs (int): Number of epochs to wait before stopping if no improvement in validation loss.
        patience_loss (float): Minimum change in validation loss to consider as an improvement.
        path_folder (str): Path to save the model checkpoints.

    Returns:
        best_val_outputs (numpy.ndarray): Best validation outputs from the model.
        best_val_labels (numpy.ndarray): Best validation labels corresponding to the outputs.
        best_model (torch.nn.Module): The best model based on validation loss.
        train_losses (list): List of training losses for each epoch.
        val_losses (list): List of validation losses for each epoch.
    """


    if optimizer_choice == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_choice == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_choice == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError("Invalid optimizer choice. Please choose 'Adam' or 'SGD'.")

    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.8)

    # Initialize lists to store loss values and validation predictions
    train_losses = []
    val_losses = []
    best_val_loss = float("inf")
    best_val_outputs = None
    best_val_labels = None


    pbar = tqdm(range(num_epochs), desc="TRAIN : Training Progress")    
    for epoch in pbar:
        # Training Phase
        model.train()
        running_loss = 0.0

        for images, targets in train_loader:

            # Move data to GPU
            images = images.to(device)
            targets = targets.to(device)            

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images).squeeze(-1)
            loss = criterion(outputs, targets)  # Ensure target shape matches output
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Calculate average training loss for the epoch
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        #print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")
        pbar.set_postfix({"Train Loss": f"{avg_train_loss:.4f}"})

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_outputs = []  # Temporary list to store outputs for this epoch
        val_labels = []  # Temporary list to store labels for this epoch
        with torch.no_grad():
            for images, targets in val_loader:
                # Move data to GPU
                images = images.to(device)
                targets = targets.to(device)

                # Get model output
                outputs = model(images).squeeze(-1)

                # Calculate loss
                loss = criterion(outputs, targets)  # Ensure target shape matches output
                val_loss += loss.item()

                # Append outputs and targets to lists
                val_outputs.append(outputs.cpu())  # Move to CPU for concatenation
                val_labels.append(targets.cpu())

        # Concatenate outputs and labels across all batches to ensure all samples are included
        val_outputs = torch.cat(val_outputs, dim=0).numpy()
        val_labels = torch.cat(val_labels, dim=0).numpy()

        # Calculate average validation loss for the epoch
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        #print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss}")
        pbar.set_postfix({"Train Loss": f"{avg_train_loss:.4f}", "Val Loss": f"{avg_val_loss:.4f}"})

        # Check if this is the best validation loss so far and store best outputs/labels
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_outputs = val_outputs
            best_val_labels = val_labels

            best_model = model
            best_model_state = model.state_dict()
            torch.save(
                best_model_state, os.path.join(path_folder, f"./epoch_{epoch + 1}.pth")
            )


        manage_saved_models(path_folder)

        if early_stopping(val_losses, patience_epochs, patience_loss):
            return (
                best_val_outputs,
                best_val_labels,
                best_model,
                train_losses,
                val_losses,
            )
        scheduler.step()



    return best_val_outputs, best_val_labels, model, train_losses, val_losses
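
A minimal end-to-end sketch on synthetic data, assuming the function can be imported from windscangeo.impl (as the source path above suggests); the checkpoint folder must exist before training:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from windscangeo.impl import train_model  # import path inferred from the source location above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1)).to(device)  # stand-in regression model

images = torch.randn(256, 1, 32, 32)
targets = torch.rand(256) * 15.0
dataset = TensorDataset(images, targets)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset, batch_size=32)

os.makedirs("checkpoints", exist_ok=True)  # epoch_N.pth checkpoints are written here

best_outputs, best_labels, best_model, train_losses, val_losses = train_model(
    model, train_loader, val_loader,
    num_epochs=5, lr=1e-3, weight_decay=1e-5,
    criterion=nn.MSELoss(), device=device,
    optimizer_choice="Adam",
    patience_epochs=3, patience_loss=1e-4,
    path_folder="checkpoints",
)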

vectorized_solar_angles(lat, lon, time_utc)

This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function; the current implementation uses only the image input and does not use solar angles.

Parameters:

Name Type Description Default
lat ndarray

The latitude values of the scatterometer data.

required
lon ndarray

The longitude values of the scatterometer data.

required
time_utc ndarray

The observation times in UTC.

required

Returns:

Name Type Description
sza ndarray

The solar zenith angle in degrees.

saa ndarray

The solar azimuth angle in degrees.

Source code in windscangeo\func_ml.py
def vectorized_solar_angles(lat, lon, time_utc):

    """
    This function calculates the solar zenith angle (SZA) and solar azimuth angle (SAA) for a given latitude, longitude, and time. This is an archived function; the current implementation uses only the image input and does not use solar angles.

    Args:
        lat (numpy.ndarray): The latitude values of the scatterometer data.
        lon (numpy.ndarray): The longitude values of the scatterometer data.
        time_utc (numpy.ndarray): The observation times in UTC.

    Returns:
        sza (numpy.ndarray): The solar zenith angle in degrees.
        saa (numpy.ndarray): The solar azimuth angle in degrees.
    """

    # Convert time to Julian Day
    timestamp = pd.to_datetime(time_utc).tz_localize(None)
    jd = (
        timestamp.astype("datetime64[ns]").astype(np.int64) / 86400000000000 + 2440587.5
    )
    d = jd - 2451545.0  # Days since J2000

    # Mean longitude, mean anomaly, ecliptic longitude
    g = np.deg2rad((357.529 + 0.98560028 * d) % 360)  # Mean anomaly
    q = np.deg2rad((280.459 + 0.98564736 * d) % 360)  # Mean longitude
    L = (q + np.deg2rad(1.915) * np.sin(g) + np.deg2rad(0.020) * np.sin(2 * g)) % (
        2 * np.pi
    )  # Ecliptic long

    # Obliquity of the ecliptic
    e = np.deg2rad(23.439 - 0.00000036 * d)

    # Sun declination
    sin_delta = np.sin(e) * np.sin(L)
    delta = np.arcsin(sin_delta)

    # Equation of time (in minutes)
    E = 229.18 * (
        0.000075
        + 0.001868 * np.cos(g)
        - 0.032077 * np.sin(g)
        - 0.014615 * np.cos(2 * g)
        - 0.040849 * np.sin(2 * g)
    )

    # Convert time to fractional hours (UTC)
    fractional_hour = timestamp.hour + timestamp.minute / 60 + timestamp.second / 3600

    # Solar time correction
    time_offset = E + 4 * lon  # lon in degrees
    tst = fractional_hour * 60 + time_offset  # True Solar Time in minutes
    ha = np.deg2rad((tst / 4 - 180) % 360)  # Hour angle in radians

    # Convert lat/lon to radians
    lat_rad = np.deg2rad(lat)

    # Solar zenith angle
    cos_zenith = np.sin(lat_rad) * np.sin(delta) + np.cos(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    zenith = np.rad2deg(np.arccos(np.clip(cos_zenith, -1, 1)))  # in degrees

    # Solar saa angle
    sin_saa = -np.sin(ha) * np.cos(delta)
    cos_saa = np.cos(lat_rad) * np.sin(delta) - np.sin(lat_rad) * np.cos(
        delta
    ) * np.cos(ha)
    saa = np.rad2deg(np.arctan2(sin_saa, cos_saa))
    saa = (saa + 360) % 360  # Normalize

    return zenith, saa
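
A small usage example, also showing the nighttime criterion (SZA > 90°) applied by plot_wind_speeds; the import path is inferred from the source location above:

import numpy as np
import pandas as pd
from windscangeo.func_ml import vectorized_solar_angles  # import path inferred from the source location above

lat = np.array([0.0, 10.0, 20.0])
lon = np.array([-45.0, -45.0, -45.0])
times = np.full(3, pd.Timestamp("2022-06-01 15:00:00"), dtype="datetime64[ns]")

sza, saa = vectorized_solar_angles(lat, lon, times)
night_mask = np.where(sza > 90, 1, np.nan)  # same nighttime criterion as plot_wind_speeds
print(sza, saa, night_mask)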

early_stopping(valid_losses, patience_epochs, patience_loss)

Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

Parameters:

Name Type Description Default
valid_losses list

List of validation losses recorded during training.

required
patience_epochs int

Number of epochs to wait before stopping if no improvement.

required
patience_loss float

Minimum change in validation loss to consider as an improvement.

required

Returns:

Name Type Description
bool

True if training should stop, False otherwise.

Source code in windscangeo\impl.py
def early_stopping(valid_losses, patience_epochs, patience_loss):  # From @Jing
    """
    Early stopping function to determine if training should stop based on validation losses. From @ Jing Sun

    Args:
        valid_losses (list): List of validation losses recorded during training.
        patience_epochs (int): Number of epochs to wait before stopping if no improvement.
        patience_loss (float): Minimum change in validation loss to consider as an improvement.

    Returns:
        bool: True if training should stop, False otherwise.
    """
    if len(valid_losses) < patience_epochs:
        return False
    recent_losses = valid_losses[-patience_epochs:]

    if all(x >= recent_losses[0] for x in recent_losses):
        return True

    if max(recent_losses) - min(recent_losses) < patience_loss:
        return True
    return False
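
Three small examples of how the two stopping conditions behave (the import path is inferred from the source location above):

from windscangeo.impl import early_stopping  # import path inferred from the source location above

# Fewer recorded epochs than patience_epochs -> keep training
print(early_stopping([0.90, 0.85, 0.80], patience_epochs=5, patience_loss=0.01))        # False

# No loss in the last 4 epochs dropped below the first of those epochs -> stop
print(early_stopping([0.50, 0.52, 0.53, 0.55], patience_epochs=4, patience_loss=0.01))  # True

# Loss varies by less than patience_loss over the last 3 epochs -> stop
print(early_stopping([0.80, 0.60, 0.401, 0.400, 0.402], patience_epochs=3, patience_loss=0.01))  # True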

inference_model(model, inference_loader, device)

Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

Parameters:

Name Type Description Default
model Module

The trained model to be used for inference.

required
inference_loader DataLoader

DataLoader for the inference dataset.

required
device device

Device to run the model on (CPU or GPU).

required

Returns:

Name Type Description
inference_outputs ndarray

Outputs from the model on the inference dataset.

Source code in windscangeo\impl.py
def inference_model(model, inference_loader, device):
    """
    Perform inference on the model using the provided DataLoader and return the outputs. Same as train_model but for a fixed given model.

    Args:
        model (torch.nn.Module): The trained model to be used for inference.
        inference_loader (torch.utils.data.DataLoader): DataLoader for the inference dataset.
        device (torch.device): Device to run the model on (CPU or GPU).

    Returns:
        inference_outputs (numpy.ndarray): Outputs from the model on the inference dataset.
    """

    with torch.no_grad():  # Disable gradient calculation for inference

        inference_outputs = []

        for images in inference_loader:
            images = images.to(device)

            outputs = model(images).squeeze(-1)

            # Append outputs to the list
            inference_outputs.append(outputs)

        inference_outputs = torch.cat(inference_outputs, dim=0)
        inference_outputs = inference_outputs.cpu()
        inference_outputs = inference_outputs.numpy()

    return inference_outputs
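
A minimal sketch with a stand-in model (the import path is inferred from the source location above). Note that inference_model only disables gradient tracking, so the model is switched to evaluation mode beforehand:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from windscangeo.impl import inference_model  # import path inferred from the source location above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1)).to(device)
model.eval()  # inference_model does not call eval() itself

# The inference loader yields image batches only (no targets)
images = torch.randn(64, 1, 32, 32)
inference_loader = DataLoader(images, batch_size=16)

predictions = inference_model(model, inference_loader, device)
print(predictions.shape)  # (64,)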

manage_saved_models(directory)

Manage saved model files in the specified directory by deleting older epoch files. Keeps only the latest epoch file and deletes all others. From @ Jing Sun

Parameters:

Name Type Description Default
directory str

The directory where model files are saved.

required
Source code in windscangeo\impl.py
def manage_saved_models(directory):  # From @Jing
    """
    Manage saved model files in the specified directory by deleting older epoch files.
    Keeps only the latest epoch file and deletes all others. From @ Jing Sun

    Args:
        directory (str): The directory where model files are saved.
    """

    pattern = re.compile(r"epoch_(\d+)\.pth")
    epoch_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            match = pattern.match(file)
            if match:
                epoch_num = int(match.group(1))
                file_path = os.path.join(root, file)
                epoch_files.append((file_path, epoch_num))

    # If more than one checkpoint exists, keep only the latest epoch file
    if len(epoch_files) > 1:
        epoch_files.sort(key=lambda x: x[1])
        files_to_delete = len(epoch_files) - 1

        for i in range(files_to_delete):
            os.remove(epoch_files[i][0])
