diff --git a/src/dswx_sar/save_mgrs_tiles.py b/src/dswx_sar/save_mgrs_tiles.py index 680c00a..f7f50d3 100644 --- a/src/dswx_sar/save_mgrs_tiles.py +++ b/src/dswx_sar/save_mgrs_tiles.py @@ -7,6 +7,7 @@ import os import time +from collections import Counter import geopandas as gpd import mgrs import numpy as np @@ -327,7 +328,8 @@ def crop_and_save_mgrs_tile( def get_intersecting_mgrs_tiles_list_from_db( image_tif, mgrs_collection_file, - track_number=None): + track_number=None, + burst_list=None): """Find and return a list of MGRS tiles that intersect a reference GeoTIFF file By searching in database @@ -341,6 +343,8 @@ def get_intersecting_mgrs_tiles_list_from_db( track_number : int, optional Track number (or relative orbit number) to specify MGRS tile collection + burst_list : list, optional + List of burst IDs to filter the MGRS tiles. Returns ------- @@ -349,10 +353,40 @@ def get_intersecting_mgrs_tiles_list_from_db( most_overlapped : GeoSeries The record of the MGRS tile with the maximum overlap area. """ + vector_gdf = gpd.read_file(mgrs_collection_file) + # Step 1: Filter by burst_list if provided + if burst_list is not None: + def burst_overlap(row): + row_bursts = ast.literal_eval(row['bursts']) + return len(set(burst_list).intersection(set(row_bursts))) + + vector_gdf['burst_overlap_count'] = vector_gdf.apply(burst_overlap, axis=1) + max_burst_overlap = vector_gdf['burst_overlap_count'].max() + vector_gdf = vector_gdf[vector_gdf['burst_overlap_count'] == max_burst_overlap] + + # If only one record matches, return it immediately + if len(vector_gdf) == 1: + logger.info('MGRS collection ID found from burst_list', vector_gdf['burst_overlap_count'] ) + mgrs_list = ast.literal_eval(vector_gdf.iloc[0]['mgrs_tiles']) + return list(set(mgrs_list)), vector_gdf.iloc[0] + + # Step 2: Filter by track_number, track_number + 1, and track_number - 1 if provided + if track_number is not None: + valid_track_numbers = [track_number, track_number + 1, track_number - 1] + vector_gdf = vector_gdf[ + vector_gdf['relative_orbit_number'].isin(valid_track_numbers) + ].to_crs("EPSG:4326") + + # If only one record matches, return it immediately + if len(vector_gdf) == 1: + mgrs_list = ast.literal_eval(vector_gdf.iloc[0]['mgrs_tiles']) + return list(set(mgrs_list)), vector_gdf.iloc[0] + else: + vector_gdf = vector_gdf.to_crs("EPSG:4326") + # Load the raster data with rasterio.open(image_tif) as src: epsg_code = src.crs.to_epsg() or 4326 - # Get bounds of the raster data left, bottom, right, top = src.bounds # Reproject to EPSG 4326 if the current EPSG is not 4326 @@ -371,7 +405,6 @@ def get_intersecting_mgrs_tiles_list_from_db( logger.info('The mosaic image crosses the antimeridian.') # Create a GeoDataFrame from the raster polygon if antimeridian_crossing_flag: - # Create a Polygon from the bounds raster_polygon_left = Polygon( [(left, bottom), (left, top), @@ -387,7 +420,6 @@ def get_intersecting_mgrs_tiles_list_from_db( raster_polygon_right], crs=4326) else: - # Create a Polygon from the bounds raster_polygon = Polygon( [(left, bottom), (left, top), @@ -397,18 +429,6 @@ def get_intersecting_mgrs_tiles_list_from_db( geometry=[raster_polygon], crs=4326) - # Load the vector data - vector_gdf = gpd.read_file(mgrs_collection_file) - - # If track number is given, then search MGRS tile collection with - # track number - if track_number is not None: - vector_gdf = vector_gdf[ - vector_gdf['relative_orbit_number'] == - track_number].to_crs("EPSG:4326") - else: - vector_gdf = vector_gdf.to_crs("EPSG:4326") - # Calculate the intersection intersection = gpd.overlay(raster_gdf, vector_gdf, @@ -416,11 +436,10 @@ def get_intersecting_mgrs_tiles_list_from_db( # Add a new column with the intersection area intersection['Area'] = intersection.to_crs(epsg=epsg_code).geometry.area + sorted_intersection = intersection.sort_values(by='Area', ascending=False) - # Find the polygon with the maximum intersection area - most_overlapped = intersection.loc[intersection['Area'].idxmax()] - - mgrs_list = ast.literal_eval(most_overlapped['mgrs_tiles']) + most_overlapped = sorted_intersection.iloc[0] if len(sorted_intersection) > 0 else None + mgrs_list = ast.literal_eval(most_overlapped['mgrs_tiles']) if most_overlapped is not None else [] return list(set(mgrs_list)), most_overlapped @@ -432,7 +451,7 @@ def get_mgrs_tiles_list_from_db(mgrs_collection_file, Parameters ---------- mgrs_collection_file : str - Path to the file containing the MGRS tile collection. + Path to the file containing the MGRS tile collection. This file should be readable by GeoPandas. mgrs_tile_collection_id : str The ID of the MGRS tile collection from which to retrieve the MGRS tiles. @@ -445,8 +464,7 @@ def get_mgrs_tiles_list_from_db(mgrs_collection_file, vector_gdf = gpd.read_file(mgrs_collection_file) most_overlapped = vector_gdf[ vector_gdf['mgrs_set_id'] == mgrs_tile_collection_id].iloc[0] - print(most_overlapped) - mgrs_list = ast.literal_eval(most_overlapped['mgrs_tiles']) + mgrs_list = ast.literal_eval(most_overlapped['mgrs_tiles']) return list(set(mgrs_list)), most_overlapped @@ -634,6 +652,7 @@ def run(cfg): logger.info(f'Number of bursts to process: {num_input_path}') date_str_list = [] + track_number_list = [] for input_dir in input_list: # Find HDF5 metadata metadata_path_iter = glob.iglob(f'{input_dir}/*{co_pol}*.tif') @@ -645,6 +664,10 @@ def run(cfg): track_number = int(tags['TRACK_NUMBER']) resolution = src.transform[0] date_str_list.append(date_str) + track_number_list.append(track_number) + counter = Counter(np.array(track_number_list)) + most_common = counter.most_common() + track_number = most_common[0][0] input_date_format = "%Y-%m-%dT%H:%M:%S" output_date_format = "%Y%m%dT%H%M%SZ" @@ -983,14 +1006,16 @@ def run(cfg): mgrs_meta_dict = {} if database_bool: - # In the case that mgrs_tile_collection_id is given + actual_burst_id = collect_burst_id(input_list, + DSWX_S1_POL_DICT['CO_POL']) + # In the case that mgrs_tile_collection_id is given # from input, then extract the MGRS list from database if input_mgrs_collection_id is not None: mgrs_tile_list, most_overlapped = \ get_mgrs_tiles_list_from_db( mgrs_collection_file=mgrs_collection_db_path, mgrs_tile_collection_id=input_mgrs_collection_id) - # In the case that mgrs_tile_collection_id is not given + # In the case that mgrs_tile_collection_id is not given # from input, then extract the MGRS list from database # using track number and area intersecting with image_tif else: @@ -998,14 +1023,14 @@ def run(cfg): get_intersecting_mgrs_tiles_list_from_db( mgrs_collection_file=mgrs_collection_db_path, image_tif=paths['final_water'], - track_number=track_number + track_number=track_number, + burst_list=actual_burst_id ) + track_number = most_overlapped['relative_orbit_number'] maximum_burst = most_overlapped['number_of_bursts'] # convert string to list expected_burst_list = ast.literal_eval(most_overlapped['bursts']) logger.info(f"Input RTCs are within {most_overlapped['mgrs_set_id']}") - actual_burst_id = collect_burst_id(input_list, - DSWX_S1_POL_DICT['CO_POL']) number_burst = len(actual_burst_id) mgrs_meta_dict['MGRS_COLLECTION_EXPECTED_NUMBER_OF_BURSTS'] = \ maximum_burst