import os
from itertools import product
import numpy as np
from affine import Affine
import matplotlib.pyplot as plt
import rasterio as rio
from rasterio import windows
from rasterio.enums import Resampling
from rasterio.transform import xy
from rasterio.warp import transform
from rasterio.transform import rowcol
import math
from rasterio.warp import reproject, Resampling
from rasterio.transform import from_origin
from ipywidgets import Output
from IPython.display import display, clear_output
import IPython
from IPython import get_ipython
import matplotlib
from matplotlib.widgets import Button
import ipywidgets as widgets
import warnings
if os.environ.get('DISPLAY', '') == '':
    # Use the non-interactive 'agg' backend when no graphical display is
    # detected (e.g. headless server), so plotting calls do not fail.
    matplotlib.use('agg')
#else:
    # Use 'tkagg' for a standard interactive display
    # matplotlib.use('tkagg')
[docs]
def get_crs_from_rdc(rdc_meta):
    """
    Creates a CRS object based on the information extracted from the RDC file.

    Args:
        rdc_meta (dict): A dictionary containing metadata from the RDC file.

    Returns:
        CRS: A CRS object defined by rasterio.
    """
    # Bug fix: CRS was referenced without being imported anywhere in this
    # module, so every call raised NameError. Import it locally.
    from rasterio.crs import CRS
    ref_system = rdc_meta.get("ref. system", "").lower()
    if ref_system == "plane":
        # "plane" means a simple local Cartesian system; UTM zone 33N is a
        # placeholder — adjust the zone for your specific region.
        return CRS.from_proj4("+proj=utm +zone=33 +datum=WGS84 +units=m +no_defs")
    # Every other (or unrecognized) reference system falls back to WGS84
    # (EPSG:4326); the original's "wgs84" and default branches were identical.
    return CRS.from_epsg(4326)
[docs]
def isfile_casesensitive(path):
    """Return True only when *path* is an existing file whose name matches
    exactly, including letter case (useful on case-insensitive filesystems)."""
    if not os.path.isfile(path):
        return False
    folder, fname = os.path.split(path)
    return fname in os.listdir(folder)
[docs]
def find_rst_and_rdc(base_name):
    """
    Finds the correct file names for .rst and .rdc files, handling case sensitivity.

    Args:
        base_name (str): The base name of the files (without extension).

    Returns:
        tuple: (rst_file, rdc_file) full paths, each None when not found.
    """
    rst_file = None
    rdc_file = None
    # Probe every common capitalisation of both extensions; a later match
    # overrides an earlier one, exactly like the original loop.
    for ext in (".rst", ".Rst", ".RST", ".rdc", ".Rdc", ".RDC"):
        candidate = f"{base_name}{ext}"
        if not isfile_casesensitive(candidate):
            continue
        print(' path is file : ', candidate)
        lowered = ext.lower()
        if lowered == ".rst":
            rst_file = candidate
        elif lowered == ".rdc":
            rdc_file = candidate
    return rst_file, rdc_file
[docs]
def read_rst_with_rdc(rst_file, rdc_file, target_crs="EPSG:4326"):
    """
    Reads an .rst file with its associated .rdc metadata file, reprojects the data
    if the coordinate system is not standard, and returns the data array and metadata.

    Args:
        rst_file (str): Path to the .rst raster file.
        rdc_file (str): Path to the associated .rdc metadata file.
        target_crs (str): Target CRS for reprojection (default is "EPSG:4326").

    Returns:
        tuple:
            - data (numpy.ndarray): The raster data array (bands x rows x cols).
            - metadata (dict): Metadata including the CRS, transform, and other info.
    """
    # Bug fix: calculate_default_transform was used below without ever being
    # imported in this module (NameError at call time).
    from rasterio.warp import calculate_default_transform

    # Step 1: Parse the .rdc file to extract metadata.
    def parse_rdc(rdc_path):
        """Reads the .rdc file and returns a dict of 'key: value' entries."""
        meta = {}
        with open(rdc_path, 'r') as file:
            for line in file:
                if ":" in line:
                    key, value = line.split(":", 1)
                    meta[key.strip()] = value.strip()
        return meta

    rdc_meta = parse_rdc(rdc_file)

    # Step 2: Extract relevant metadata from the .rdc file.
    ref_system = rdc_meta.get("ref. system", "").lower()
    ref_units = rdc_meta.get("ref. units", "").lower()
    min_x = float(rdc_meta.get("min. X", 0))
    max_x = float(rdc_meta.get("max. X", 0))
    min_y = float(rdc_meta.get("min. Y", 0))
    max_y = float(rdc_meta.get("max. Y", 0))
    print('ref_system', ref_system)
    print('ref_units', ref_units)

    # Bug fix: the original called the undefined helper parse_rdc_meta() here
    # (NameError); the metadata has already been parsed above.
    crs = get_crs_from_rdc(rdc_meta)

    # Step 3: Read the .rst file and attach the inferred georeferencing.
    with rio.open(rst_file) as src:
        src_transform = rio.transform.from_bounds(
            min_x, min_y, max_x, max_y, src.width, src.height
        )
        metadata = src.meta.copy()
        metadata.update({
            "crs": crs,
            "transform": src_transform
        })
        data = src.read()  # all bands: (bands, rows, cols)

        # Step 4: Reproject if the CRS is not the target CRS.
        if crs.to_string() != target_crs:
            dst_transform, width, height = calculate_default_transform(
                crs, target_crs, src.width, src.height, *src.bounds
            )
            # Bug fix: the original passed the same (destination) transform as
            # both src_transform and dst_transform, and allocated a 2-D
            # destination even for multi-band data.
            reprojected_data = np.empty(
                (data.shape[0], height, width), dtype=data.dtype
            )
            reproject(
                source=data,
                destination=reprojected_data,
                src_transform=src_transform,
                src_crs=crs,
                dst_transform=dst_transform,
                dst_crs=target_crs,
                resampling=Resampling.nearest
            )
            metadata.update({
                "crs": target_crs,
                "transform": dst_transform,
                "width": width,
                "height": height
            })
            return reprojected_data, metadata

    # No reprojection was needed: return the original data and metadata.
    return data, metadata
[docs]
def list_fles(path, ext):
    """
    Return the alphabetically sorted list of files in folder *path* whose
    names end with extension *ext*, joined with the folder path.
    """
    matching = sorted(f for f in os.listdir(path) if f.endswith(ext))
    return [os.path.join(path, f) for f in matching]
[docs]
def normalize(band, percentile=2):
    """
    Linearly rescale *band* to uint8 [0, 255].

    With percentile == 0 the full nan-aware min/max range is used; otherwise
    the range is taken at [percentile, 100 - percentile]. Boolean arrays are
    returned unchanged.
    """
    if band.dtype == 'bool':
        return band
    if percentile == 0:
        lo = np.nanmin(band)
        hi = np.nanmax(band)
        scaled = (band - lo) / (hi - lo)
    else:
        lo = np.nanpercentile(band, percentile)
        hi = np.nanpercentile(band, 100. - percentile)
        scaled = (band.astype(np.float64) - lo) / (hi - lo).astype(np.float64)
    return (255. * np.clip(scaled, 0., 1.)).astype(np.uint8)
[docs]
def get_percentile(band, percentile=2):
    """
    Clip *band* to its [percentile, 100 - percentile] value range.

    percentile == 0 or a boolean array returns the input unchanged.
    """
    if percentile == 0 or band.dtype == 'bool':
        return band
    lo = np.nanpercentile(band, percentile)
    hi = np.nanpercentile(band, 100. - percentile)
    return np.clip(band, lo, hi)
[docs]
def get_tiles_recovery(ds, width=128, height=128, overlap=28):
    """
    Yield (window, transform) pairs splitting the raster *ds* into
    width x height snippets that overlap by *overlap* pixels.

    Each window is clipped to the raster extent before being yielded.
    """
    ncols, nrows = ds.meta['width'], ds.meta['height']
    full_extent = windows.Window(col_off=0, row_off=0, width=ncols, height=nrows)
    col_starts = range(-overlap, ncols, width - overlap)
    row_starts = range(-overlap, nrows, height - overlap)
    for col_off, row_off in product(col_starts, row_starts):
        win = windows.Window(
            col_off=col_off, row_off=row_off, width=width, height=height
        ).intersection(full_extent)
        yield win, windows.transform(win, ds.transform)
[docs]
def split_image_to_tiles(source_name, dest_name, size_row, size_col, overlap=0, name='sequence', verbose=0, name_tile=None):
    """
    Split a georeferenced image into smaller tiles with optional overlap and save the resulting tiles and associated georeferencing information.

    Parameters
    ----------
    source_name : str
        The file path to the input georeferenced image (e.g., a GeoTIFF) that needs to be split.
    dest_name : str
        The destination folder where the tiles and their associated georeferencing information will be saved.
    size_row : int
        The number of rows (row) for each tile. This defines the vertical size of each tile.
    size_col : int
        The number of columns (col) for each tile. This defines the horizontal size of each tile.
    overlap : int, optional, default=0
        The number of pixels to overlap between adjacent tiles. If set to 0, there will be no overlap between tiles.
    name : str, optional, default='sequence'
        The base name for the saved tiles. The tiles will be named sequentially (e.g., 'sequence_1.tif', 'sequence_2.tif', etc.) unless `name_tile` is provided.
    verbose : int, optional, default=0
        If set to 1 or higher, additional information about the process will be printed (e.g., progress messages).
    name_tile : str, optional, default=None
        If specified, the base name for the tiles will be overridden. Each tile will be named according to this value (e.g., 'tile_1.tif', 'tile_2.tif', etc.).

    Returns
    -------
    int
        Always 1. The tiles themselves (with their georeferencing) are written
        into `dest_name` as a side effect.

    Examples
    --------
    >>> split_image_to_tiles('input_image.tif', 'tiles_folder', 256, 256)
    >>> # Split an image into 256x256 tiles with no overlap:
    >>>
    >>> split_image_to_tiles('input_image.tif', 'tiles_folder', 256, 256, overlap=50)
    >>> # Split an image into 256x256 tiles with 50 pixels overlap:
    >>>
    >>> split_image_to_tiles('input_image.tif', 'tiles_folder', 128, 128, name='part')
    >>> # Split an image into 128x128 tiles and name the tiles 'part_1', 'part_2', etc.:

    Notes
    -----
    - The function reads a georeferenced image, splits it into smaller tiles of size `size_row` x `size_col`, and saves each tile with its associated georeferencing information.
    - If an overlap is specified, the tiles will include the specified number of pixels from adjacent tiles, ensuring that there is continuity across the tiles when reconstructed.
    - Only tiles whose clipped window is exactly `size_row` x `size_col` are
      written; smaller edge tiles are skipped.
    """
    # Base name for the output tiles: derived from the source file name
    # unless explicitly overridden by name_tile.
    if name_tile is None:
        generic_name=os.path.splitext(os.path.split(source_name)[1])[0]
    else:
        generic_name = name_tile
    # 'sequence' mode numbers tiles 0,1,2,...; any other mode encodes the
    # window's (col_off, row_off) in the file name instead.
    if name=='sequence':
        output_filename='%s/%s_tiles_{:05d}.tif'%(dest_name,generic_name)
    else:
        output_filename='%s/%s_tiles_{:05d}-{:05d}.tif'%(dest_name,generic_name)
    i=0
    #out_path=os.path.dirname(dest_name)
    with rio.open(source_name) as inds:
        tile_width, tile_height = size_row, size_col
        meta = inds.meta.copy()
        if verbose ==1:
            # NOTE(review): this prints the SOURCE image size (meta of the
            # full raster), not the tile size, despite the message text.
            print('size of the tile : %d x %d'%(meta['width'], meta['height']))
        meta['driver']='GTiff'
        for window, transform in get_tiles_recovery(inds, size_row, size_col,overlap):
            # Per-tile georeferencing: window-local transform and size.
            meta['transform'] = transform
            meta['width'], meta['height'] = window.width, window.height
            # outpath = os.path.join(out_path, output_filename.format(int(window.col_off), int(window.row_off)))
            # outpath = output_filename.format(int(window.col_off), int(window.row_off))
            if name == 'sequence':
                outpath = output_filename.format(int(i))
            else:
                outpath = output_filename.format(int(window.col_off), int(window.row_off))
            # Write only full-size tiles; edge windows clipped smaller than
            # the requested tile shape are skipped.
            if ((meta['width']==size_row) and (meta['height']==size_col)):
                i=i+1
                with rio.open(outpath, 'w', **meta) as outds:
                    outds.write(inds.read(window=window))
    return 1
[docs]
def resampling_source(source_name, final_resolution, dest_name=None, method='cubic_spline', channel_first=True):
    """
    Resamples a GeoTIFF image to a specified resolution.

    Args:
        source_name (str): Name of the input GeoTIFF image.
        final_resolution (float): Desired resolution in meters.
        dest_name (str, optional): Name of the output resampled image. Defaults to None.
        method (str, optional): Resampling method. Defaults to "cubic_spline".
            See details here:
            https://rasterio.readthedocs.io/en/stable/api/rasterio.enums.html#rasterio.enums.Resampling.
        channel_first (bool, optional): If True, outputs an image with shape
            (bands x rows x cols). If False, outputs an image with shape
            (rows x cols x bands). Defaults to True.

    Returns:
        - image (numpy.ndarray): The resampled image.
        - meta (dict): Metadata of the resampled image.

    Raises:
        ValueError: If *method* is not a supported resampling method name.
    """
    # Dict dispatch replaces the original 13-branch if/elif chain.
    method_map = {
        'cubic_spline': Resampling.cubic_spline,
        'nearest': Resampling.nearest,
        'bilinear': Resampling.bilinear,
        'cubic': Resampling.cubic,
        'lanczos': Resampling.lanczos,
        'average': Resampling.average,
        'mode': Resampling.mode,
        'max': Resampling.max,
        'min': Resampling.min,
        'med': Resampling.med,
        'sum': Resampling.sum,
        'q1': Resampling.q1,
        'q3': Resampling.q3,
    }
    if method not in method_map:
        raise ValueError("Error : unknown interpolation method %s"%method)
    resample_method = method_map[method]
    with rio.open(source_name) as dataset:
        resolution_originale = dataset.meta['transform'][0]
        # Boolean rasters cannot be interpolated: force nearest neighbour.
        if dataset.meta['dtype'] == 'bool':
            resample_method = Resampling.nearest
        # Scale factor from the current pixel size to the requested one.
        coef_multiplication = np.float64(resolution_originale / final_resolution)
        data = dataset.read(
            out_shape=(
                dataset.count,
                int(dataset.height * coef_multiplication),
                int(dataset.width * coef_multiplication)
            ),
            resampling=resample_method
        )
        # Rescale the affine transform to the new pixel grid.
        transform = dataset.transform * dataset.transform.scale(
            (dataset.width / data.shape[-1]),
            (dataset.height / data.shape[-2])
        )
        meta = dataset.meta.copy()
        meta['width'] = int(dataset.width * coef_multiplication)
        meta['height'] = int(dataset.height * coef_multiplication)
        meta['transform'] = transform
    # Optionally write the resampled raster, creating the folder if needed.
    if dest_name is not None:
        folder = os.path.split(dest_name)[0]
        if folder != '' and not os.path.exists(folder):
            os.makedirs(folder)
        with rio.open(dest_name, 'w', **meta) as dst:
            dst.write(data)
    if channel_first is True:
        return data, meta
    return np.rollaxis(data, 0, 3), meta
[docs]
def resample_image_with_resolution(
        src_filename, resolution, method='cubic_spline',
        dest_name=None, channel_first=True):
    """
    Resamples an image to a specified resolution.

    Args:
        src_filename (str): Name of the input GeoTIFF image.
        resolution (float): Desired resolution in meters.
        method (str, optional): Resampling method. Defaults to "cubic_spline".
            See details here:
            https://rasterio.readthedocs.io/en/stable/api/rasterio.enums.html#rasterio.enums.Resampling.
        dest_name (str, optional): Name of the output resampled image. Defaults to None.
        channel_first (bool, optional): If True, outputs an image with shape
            (bands x rows x cols). If False, outputs an image with shape
            (rows x cols x bands). Defaults to True.

    Returns:
        tuple: A tuple containing:
            - image (numpy.ndarray): The resampled image.
            - meta (dict): Metadata of the resampled image.

    Raises:
        ValueError: If *method* is not a supported resampling method name.
    """
    # Dict dispatch replaces the original 13-branch if/elif chain.
    method_map = {
        'cubic_spline': Resampling.cubic_spline,
        'nearest': Resampling.nearest,
        'bilinear': Resampling.bilinear,
        'cubic': Resampling.cubic,
        'lanczos': Resampling.lanczos,
        'average': Resampling.average,
        'mode': Resampling.mode,
        'max': Resampling.max,
        'min': Resampling.min,
        'med': Resampling.med,
        'sum': Resampling.sum,
        'q1': Resampling.q1,
        'q3': Resampling.q3,
    }
    if method not in method_map:
        raise ValueError("Error: unknown interpolation method %s" % method)
    resample_method = method_map[method]
    # Open the source file and derive the target grid from its bounds.
    with rio.open(src_filename) as src:
        src_array = src.read()
        src_meta = src.meta
        src_transform = src.transform
        new_width = int((src.bounds.right - src.bounds.left) / resolution)
        new_height = int((src.bounds.top - src.bounds.bottom) / resolution)
        new_transform = from_origin(src.bounds.left, src.bounds.top, resolution, resolution)
    # Update the destination metadata for the new grid.
    dst_meta = src_meta.copy()
    dst_meta.update({
        'height': new_height,
        'width': new_width,
        'transform': new_transform
    })
    data_bool = False
    if src_array.dtype == 'bool':
        data_bool = True
        print("Traitement d'une image booléenne")
        # Convert booleans to 0/1 integers for resampling and force
        # nearest-neighbour interpolation.
        src_array = src_array.astype(np.uint8)
        resample_method = Resampling.nearest
    # Allocate the destination array and resample within the same CRS.
    dst_array = np.empty((src_meta['count'], new_height, new_width), dtype=src_array.dtype)
    reproject(
        source=src_array,
        destination=dst_array,
        src_transform=src_transform,
        dst_transform=new_transform,
        src_crs=src_meta['crs'],
        dst_crs=src_meta['crs'],
        resampling=resample_method
    )
    # If the original data was boolean, convert back to boolean.
    if data_bool:
        dst_array = dst_array.astype(bool)
    # Optionally write the resampled raster, creating the folder if needed.
    if dest_name is not None:
        folder = os.path.split(dest_name)[0]
        if folder != '' and not os.path.exists(folder):
            os.makedirs(folder)
        with rio.open(dest_name, 'w', **dst_meta) as dst:
            dst.write(dst_array)
    if channel_first is True:
        return dst_array, dst_meta
    return np.rollaxis(dst_array, 0, 3), dst_meta
[docs]
def resampling_image(image, meta, final_resolution, dest_name=None, method='cubic_spline', channel_first=True):
    """
    Resampling of an in-memory image.

    - image: the input image as a numpy array
    - meta: metadata of the image (dictionary with 'transform' and 'crs')
    - final_resolution: the desired resolution (in meters)
    - dest_name: (optional) name of the resampled image file
    - method: resampling method (default: cubic_spline)
    - channel_first (optional, default True): output an image of shape
      (bands x height x width), otherwise (height x width x bands)
    - return: resampled image, updated meta

    Raises ValueError for an unknown *method*.
    """
    # Dict dispatch replaces the original 14-branch if/elif chain.
    method_map = {
        'cubic_spline': Resampling.cubic_spline,
        'nearest': Resampling.nearest,
        'bilinear': Resampling.bilinear,
        'cubic': Resampling.cubic,
        'lanczos': Resampling.lanczos,
        'average': Resampling.average,
        'mode': Resampling.mode,
        'gauss': Resampling.gauss,
        'max': Resampling.max,
        'min': Resampling.min,
        'med': Resampling.med,
        'sum': Resampling.sum,
        'q1': Resampling.q1,
        'q3': Resampling.q3,
    }
    if method not in method_map:
        raise ValueError("Error: unknown interpolation method %s" % method)
    resample_method = method_map[method]
    data_bool = False
    if image.dtype == 'bool':
        # Booleans cannot be interpolated: resample 0/1 with nearest.
        data_bool = True
        resample_method = Resampling.nearest
        image = image.astype(np.uint8)
    # Scaling factor from the current pixel size to the requested one.
    original_resolution = meta['transform'][0]
    scale_factor = np.float64(original_resolution / final_resolution)
    # Work internally in (bands, rows, cols) order.
    if channel_first is False:
        image = np.rollaxis(image, 2, 0)
    bands, height, width = image.shape
    new_height = int(height * scale_factor)
    new_width = int(width * scale_factor)
    # The rescaled transform (hoisted: the original computed this identical
    # expression once per band and once more for the metadata).
    new_transform = meta['transform'] * meta['transform'].scale(
        (width / new_width),
        (height / new_height)
    )
    # Band-by-band reprojection within the same CRS performs the resampling.
    data = np.empty((bands, new_height, new_width), dtype=image.dtype)
    for band_index in range(bands):
        data[band_index] = rio.warp.reproject(
            source=image[band_index],
            destination=np.empty((new_height, new_width), dtype=image.dtype),
            src_transform=meta['transform'],
            dst_transform=new_transform,
            src_crs=meta['crs'],
            dst_crs=meta['crs'],
            resampling=resample_method
        )[0]
    # Update metadata for the resampled image.
    new_meta = meta.copy()
    new_meta.update({
        'width': new_width,
        'height': new_height,
        'transform': new_transform
    })
    if data_bool:
        data = data.astype(bool)
    # Optionally save the resampled image (folder is created when needed).
    if dest_name is not None:
        folder = os.path.split(dest_name)[0]
        if folder != '' and not os.path.exists(folder):
            os.makedirs(folder)
        with rio.open(dest_name, 'w', **new_meta) as dst:
            dst.write(data)
    # Adjust output shape based on channel_first flag.
    if channel_first:
        return data, new_meta
    return np.rollaxis(data, 0, 3), new_meta
[docs]
def numpy_to_string_list(a, delimiter=', '):
    """Converts a NumPy object into a list of strings.

    Args:
        a: A NumPy object (integer, array of integers or strings).
        delimiter: Unused; kept for backward compatibility with callers.

    Returns:
        A list of strings.
    """
    if isinstance(a, str):
        return [a]
    if isinstance(a, int):
        return [str(a)]
    # Coerce to a NumPy array, then stringify each element.
    return [str(item) for item in np.array(a)]
[docs]
def latlon_to_pixels(meta, lat, lon):
    """
    Converts latitude/longitude coordinates to pixel coordinates.

    Arguments:
    - meta: metadata dictionary (from src.meta), must hold 'transform' and 'crs'
    - lat: latitude
    - lon: longitude
    Returns:
    - (row, col) pixel indices, or (0, 0) when the projection is unknown
    """
    try:
        affine_tr = meta['transform']
        dst_crs = meta['crs']
        # Reproject the WGS84 point into the raster's CRS, then invert the
        # affine transform to get pixel indices.
        xs, ys = transform({'init': 'EPSG:4326'}, dst_crs, [lon], [lat])
        row, col = rowcol(affine_tr, xs[0], ys[0])
    except Exception:
        warnings.warn("Unknow projection system. Can not deal anymore with reprojection. Some function may not work ")
        return 0, 0
    return row, col
[docs]
def pixels_to_latlon(meta, pix_i, pix_j):
    """
    Converts pixel coordinates to latitude/longitude.

    Arguments:
    - meta: metadata dictionary (from src.meta), must hold 'transform' and 'crs'
    - pix_i: row (i) coordinate of the pixel
    - pix_j: column (j) coordinate of the pixel
    Returns:
    - (lat, lon), or (0., 0.) when the projection is unknown
    """
    try:
        affine_tr = meta['transform']
        src_crs = meta['crs']
        # Map the pixel to projected coordinates, then reproject to WGS84.
        x, y = xy(affine_tr, pix_i, pix_j)
        lons, lats = transform(src_crs, {'init': 'EPSG:4326'}, [x], [y])
        lat, lon = lats[0], lons[0]
    except Exception:
        warnings.warn("Unknow projection system. Can not deal anymore with reprojection. Some function may not work ")
        lat, lon = 0., 0.
    return lat, lon
[docs]
def reindex_dictionary(dictionary, keys_to_remove):
    """
    Drop the given keys from the dictionary and renumber the remaining
    keys from 1, preserving their original order.

    Args:
        dictionary: The dictionary to modify.
        keys_to_remove: A list of keys to remove.

    Returns:
        A new dictionary with the surviving keys mapped to 1..n.
    """
    survivors = [k for k in dictionary if k not in keys_to_remove]
    return {k: pos for pos, k in enumerate(survivors, start=1)}
[docs]
def reindex_dictionary_keep_order(dictionary, keys_to_keep):
    """
    Keep only the listed keys (in the order of *keys_to_keep*) and
    renumber their values from 1.

    Args:
        dictionary: The dictionary to modify.
        keys_to_keep: A list of keys to keep.

    Returns:
        A new dictionary with the kept keys mapped to 1..n.
    """
    present = (k for k in keys_to_keep if k in dictionary)
    return {k: pos for pos, k in enumerate(present, start=1)}
[docs]
def reorder_dict_by_values(dictionary):
    """
    Return a new dictionary with the same items ordered by ascending value.

    Args:
        dictionary: The dictionary to reorder.

    Returns:
        A new dictionary with items ordered by their values.
    """
    ranked = sorted(dictionary.items(), key=lambda kv: kv[1])
    return dict(ranked)
[docs]
def lowest_value_dict(dictionary):
    """
    Return the lowest value stored in *dictionary*, or 1 when it is empty.
    """
    # min() over the values replaces the original sort-then-first-key dance:
    # O(n) instead of O(n log n), and actually readable.
    if not dictionary:
        return 1
    return min(dictionary.values())
[docs]
def has_duplicates(tab):
    """Verify if a table contains repeated values.

    Args:
        tab: the table

    Returns:
        True when at least one value occurs more than once, False otherwise.
    """
    distinct = set(tab)
    return len(distinct) != len(tab)
[docs]
def split_keys(names_dict, key):
    """
    Split the dictionary's keys at *key*.

    Parameters
    ----------
    names_dict : dict
        Dictionary containing keys and their values.
    key : str
        The key at which to split.

    Returns
    -------
    list, list
        The keys up to and including *key*, and the keys after it.

    Raises
    ------
    ValueError
        When *key* is not present in the dictionary.
    """
    keys = list(names_dict)
    if key not in names_dict:
        raise ValueError("The provided key is not in the dictionary.")
    cut = keys.index(key) + 1
    return keys[:cut], keys[cut:]
[docs]
def initialize_dict(n):
    """
    Create a dictionary of the form {"1": 1, "2": 2, ..., "n": n}.
    """
    return {str(k): k for k in range(1, n + 1)}
[docs]
def add_ordered_key(dictionary_input, key_name=None):
    """
    Return a copy of an ordered dictionary with one key appended.

    When *key_name* is None a name is generated: if any existing key is fully
    numeric, the next integer after the largest such key is used; otherwise
    (or when that maximum is 0) the name is str(len(dictionary) + 1).
    The new key's value is max(existing values) + 1 (1 for an empty dict).

    Args:
        dictionary_input: The ordered dictionary to copy and extend.
        key_name: The name of the new key (optional).

    Returns:
        The extended copy; the input dictionary is not modified.
    """
    dictionary = dictionary_input.copy()
    if key_name is None:
        # Check if at least one key contains a digit anywhere.
        contains_numbers = any(any(ch.isdigit() for ch in key) for key in dictionary)
        if contains_numbers:
            # Bug fix: the original used all(char.isdigit() for char in key),
            # which is True for the empty string and made int('') raise
            # ValueError whenever an empty key was present. str.isdigit()
            # returns False for '' and is otherwise equivalent here.
            numeric_parts = [int(key) for key in dictionary if key.isdigit()]
            max_numeric_key = max(numeric_parts, default=0)
            if max_numeric_key == 0:
                key_name = str(len(dictionary) + 1)
            else:
                key_name = str(max_numeric_key + 1)
        else:
            # No key contains a digit: fall back to the size-based index.
            key_name = str(len(dictionary) + 1)
    # The new value continues the existing value sequence.
    dictionary[key_name] = max(dictionary.values(), default=0) + 1
    return dictionary
[docs]
def concat_dicts_with_keys(dict1, dict2):
    """
    Concatenates two dictionaries, suffixing keys from the first with '_1'
    and keys from the second with '_2', then assigning sequential values.

    Args:
        dict1: The first dictionary.
        dict2: The second dictionary.

    Returns:
        A new dictionary with renamed keys mapped to 1..n in order.
    """
    renamed = [f"{k}_1" for k in dict1] + [f"{k}_2" for k in dict2]
    return {key: pos for pos, key in enumerate(renamed, start=1)}
[docs]
def has_common_key(dict1, dict2):
    """
    Checks if there is at least one common key between two dictionaries.

    Args:
        dict1 (dict): The first dictionary.
        dict2 (dict): The second dictionary.

    Returns:
        bool: True if there is at least one common key, False otherwise.
    """
    # isdisjoint is the set-level "no overlap" test; negate it.
    return not set(dict1).isdisjoint(dict2)
[docs]
def concat_dicts_with_keys_unmodified(dict1, dict2):
    """
    Concatenates the key sets of two dictionaries without renaming
    (duplicate keys collapse, keeping their first position), then assigns
    sequential values starting from 1.

    Args:
        dict1: The first dictionary.
        dict2: The second dictionary.

    Returns:
        A new dictionary with the merged keys mapped to 1..n in order.
    """
    merged = dict.fromkeys(dict1)
    merged.update(dict.fromkeys(dict2))
    return {key: pos for pos, key in enumerate(merged, start=1)}
[docs]
def latlon2meters(lon1, lat1, lon2, lat2):
    """
    Great-circle (haversine) distance in meters between two points given
    as (lon, lat) pairs in degrees.
    """
    # Earth's radius (km)
    R = 6378.137
    dLat = math.radians(lat2 - lat1)
    dLon = math.radians(lon2 - lon1)
    # Haversine. Bug fix: the formula multiplies the cosines of the two
    # LATITUDES by sin(dLon/2)^2; the original swapped the roles of
    # latitude and longitude (cos of longitudes, dLat/dLon exchanged),
    # giving wrong distances away from the equator.
    a = (math.sin(dLat / 2) ** 2
         + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2))
         * math.sin(dLon / 2) ** 2)
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    d = R * c
    return d * 1000
[docs]
def calculate_bounds(meta):
    """
    Compute the geographic bounding box of an image from its metadata.

    Input : meta dict holding 'transform', 'width' and 'height'.
    Output : rasterio BoundingBox (left, bottom, right, top).
    """
    affine_tr = meta['transform']
    # Corner (0, 0) is the upper-left pixel; (width, height) the lower-right.
    left, top = affine_tr * (0, 0)
    right, bottom = affine_tr * (meta['width'], meta['height'])
    return rio.coords.BoundingBox(left, bottom, right, top)
[docs]
def ajust_sizes(im1,im2):
    """
    Crop two image objects (copies of the inputs) to a common pixel size.

    When the shapes differ, the geographic distances between the two images'
    upper-left corners and between their lower-right corners decide which
    side is cropped, so that the better-aligned corner is preserved.

    NOTE(review): im1/im2 are project image objects — this relies on their
    .copy(), .shape, .pixel2latlon(row, col) and
    .crop(row0, row1, col0, col1, inplace=True) methods, none of which are
    defined in this module. Confirm their semantics against the image class.

    Returns:
        (im1common, im2common): the two cropped copies.
    """
    im1common=im1.copy()
    im2common=im2.copy()
    row1=im1common.shape[0]
    row2=im2common.shape[0]
    col1=im1common.shape[1]
    col2=im2common.shape[1]
    offset_col=col2-col1
    offset_row=row2-row1
    if offset_col != 0 or offset_row !=0:
        # Distance (in meters) between the two lower-right corners...
        lat1,lon1=im2common.pixel2latlon(im2common.shape[0],im2common.shape[1])
        lat2,lon2=im1common.pixel2latlon(im1common.shape[0],im1common.shape[1])
        diff_end=latlon2meters(lat1, lon1, lat2, lon2)
        # ...and between the two upper-left corners.
        lat1,lon1=im2common.pixel2latlon(0,0)
        lat2,lon2=im1common.pixel2latlon(0,0)
        diff_deb=latlon2meters(lat1, lon1, lat2, lon2)
        row1=im1common.shape[0]
        row2=im2common.shape[0]
        col1=im1common.shape[1]
        col2=im2common.shape[1]
        if ((row1!= row2) or (col1!= col2)):
            if diff_deb < diff_end:
                # Upper-left corners are closer: keep the origin and crop
                # both images down to the smallest common size.
                im1common.crop(0,np.min((row1,row2)),0,np.min((col1,col2)),inplace=True)
                im2common.crop(0,np.min((row1,row2)),0,np.min((col1,col2)),inplace=True)
            else:
                # Lower-right corners are closer: crop from the start side,
                # shifting the larger image's window so the ends align.
                if row1<=row2:
                    deb_row1=0
                    end_row1=im1common.shape[0]
                    deb_row2=(row2-row1)
                    end_row2=im2common.shape[0]
                else:
                    deb_row2=0
                    end_row2=im2common.shape[0]
                    deb_row1=(row1-row2)
                    end_row1=im1common.shape[0]
                if col1 <= col2:
                    deb_col1=0
                    end_col1=im1common.shape[1]
                    deb_col2=(col2-col1)
                    end_col2=im2common.shape[1]
                else:
                    deb_col2=0
                    end_col2=im2common.shape[1]
                    deb_col1=(col1-col2)
                    end_col1=im1common.shape[1]
                im1common.crop(deb_row1,end_row1,deb_col1,end_col1,inplace=True)
                im2common.crop(deb_row2,end_row2,deb_col2,end_col2,inplace=True)
    return im1common,im2common
[docs]
def is_value_compatible(value, dtype_str):
"""
Checks if the given value is compatible with the specified data type (provided as a string).
Parameters:
- value: the value to be checked
- dtype_str: the data type as a string (e.g., 'bool', 'uint8', 'int32', etc.)
Returns:
- True if the value is compatible with the data type, False otherwise
"""
# Convert the string to a numpy data type
try:
dtype = np.dtype(dtype_str)
except TypeError:
raise ValueError(f"The type '{dtype_str}' is not recognized.")
# Check limits based on the data type
if np.issubdtype(dtype, np.integer):
# For integer types, check if the value is an integer and within the type's limits
if isinstance(value, (int, np.integer)):
dtype_info = np.iinfo(dtype)
return dtype_info.min <= value <= dtype_info.max
else:
return False # Value is not an integer and thus incompatible
elif np.issubdtype(dtype, np.floating):
# For floating-point types, check if the value is a float and within the type's limits
if isinstance(value, (float, np.floating)):
dtype_info = np.finfo(dtype)
return dtype_info.min <= value <= dtype_info.max
else:
return False # Value is not a float and thus incompatible
elif np.issubdtype(dtype, np.bool_):
# Boolean values are limited to True/False or 1/0
return value in [0, 1, False, True]
else:
raise ValueError(f"The type '{dtype_str}' is not supported.")
[docs]
def adapt_nodata(dtype_str, nodata):
    """
    Adjusts the nodata value if it's not compatible with the specified data type.

    Parameters:
    - dtype_str: the target data type as a string (e.g., 'bool', 'uint8', 'int32', etc.)
    - nodata: the current nodata value (may be None)

    Returns:
    - *nodata* unchanged when it is None or already compatible; otherwise a
      type-appropriate replacement (int min/max, NaN for floats, 0 for bool).
    """
    # None and already-compatible values pass through untouched.
    if nodata is None or is_value_compatible(nodata, dtype_str):
        return nodata
    dtype = np.dtype(dtype_str)
    if np.issubdtype(dtype, np.integer):
        # Signed ints use the minimum; unsigned ints use the maximum.
        limits = np.iinfo(dtype)
        if np.issubdtype(dtype, np.signedinteger):
            return limits.min
        return limits.max
    if np.issubdtype(dtype, np.floating):
        # Floats default to NaN as the nodata marker.
        return np.nan
    if np.issubdtype(dtype, np.bool_):
        # Booleans default to 0 (False).
        return 0
    raise ValueError(f"The type '{dtype_str}' is not supported for nodata adaptation.")
[docs]
def reset_matplotlib(mode='inline'):
    """
    Close all open figures and re-apply a matplotlib backend mode in IPython.

    Parameters
    ----------
    mode : str, optional
        Argument passed to the ``%matplotlib`` line magic (default 'inline').
        Ignored (no-op) when not running inside an IPython session.
    """
    plt.close('all')
    shell = get_ipython()
    if shell is None:
        # Plain Python interpreter: no line magics available.
        return
    try:
        shell.run_line_magic('matplotlib', mode)
    except Exception as exc:
        print(f"Warning: Impossible to run run_line_magic : {exc}")
[docs]
def is_notebook():
    """
    Detect whether the code is running inside a Jupyter notebook kernel.

    Returns
    -------
    bool
        True when an IPython kernel ('IPKernelApp') is active, False
        otherwise (plain shell, plain Python, or IPython unavailable).
    """
    try:
        from IPython import get_ipython
    except ImportError:
        # IPython is not installed at all.
        return False
    shell = get_ipython()
    # get_ipython() returns None outside of any IPython session; the original
    # bare `except:` silently relied on the resulting AttributeError, which
    # also swallowed unrelated errors such as KeyboardInterrupt.
    return shell is not None and 'IPKernelApp' in shell.config
[docs]
def plot_clic_spectra(im, imc, figsize=(15, 5), plot_legend=False, names=None,
                      title_im="Original image (click outside or finish button to stop)",
                      title_spectra="Spectra", xlabel="Bands", ylabel="Value", callback=None):
    """
    Display an image and interactively collect spectra at clicked pixels.

    Clicking inside the left axes appends the spectrum ``im[i, j, :]`` of the
    clicked pixel to the collection and redraws all collected spectra in the
    right axes. Clicking outside the image, or pressing the 'Finish' button,
    ends the collection.

    Parameters
    ----------
    im : numpy.ndarray
        Data cube indexed as ``im[row, col, band]`` (assumed 3-D — spectra
        are read with ``im[i, j, :]``; TODO confirm with callers).
    imc : array-like
        Color/composite image shown for clicking (passed to ``imshow``).
    figsize : tuple, optional
        Figure size in inches.
    plot_legend : bool, optional
        If True, label each spectrum and draw a legend.
    names : sequence, optional
        Tick labels for the spectral axis (one per band).
    title_im, title_spectra, xlabel, ylabel : str, optional
        Plot annotations.
    callback : callable, optional
        Called as ``callback(series_spectrales, val_i, val_j)`` when the
        collection ends.

    Returns
    -------
    tuple
        ``(series_spectrales, val_i, val_j, end_collect)`` — collected
        spectra, their row/column indices, and the module-level completion
        flag. NOTE(review): with an asynchronous notebook backend the lists
        may still be filling when this function returns.
    """
    if is_notebook():
        ipython = get_ipython()
        # Switch backends so the figure is interactive inside the notebook.
        ipython.run_line_magic('matplotlib', 'inline')
        ipython.run_line_magic('matplotlib', 'widget')
    else:
        plt.ion()  # Interactive mode for the standard shell
    plt.close('all')
    series_spectrales = []
    val_i = []
    val_j = []
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    fig.suptitle("")
    ax1.imshow(imc)
    ax1.set_title(title_im)
    close_button = Button(plt.axes([0.04, 0.92, 0.1, 0.05]), 'Finish')  # Adjusted position for top-left
    global end_collect
    end_collect = False

    def close_figure(event=None):
        # BUGFIX: declare the module-level flag here as well. Without this
        # declaration the assignment below created a *local* variable, so the
        # end_collect value returned by plot_clic_spectra stayed False even
        # after the user finished clicking.
        global end_collect
        fig.canvas.mpl_disconnect(cid)
        plt.close(fig)
        ready_label.value = "Click finished. Data is ready."
        end_collect = True
        # Call the callback if provided
        if callback:
            callback(series_spectrales, val_i, val_j)
        reset_matplotlib()

    close_button.on_clicked(close_figure)

    def onclick(event):
        # If the click is outside the image, stop collecting data
        if event.inaxes != ax1:
            close_figure()
            return
        i, j = int(event.ydata), int(event.xdata)
        val_i.append(i)
        val_j.append(j)
        # Extract the spectral values at the clicked point
        serie_spectrale = im[i, j, :]  # Extract spectral data
        series_spectrales.append(serie_spectrale)
        # Update the spectrum display
        ax2.clear()
        for indice in range(len(series_spectrales)):
            label = f'Series {indice} ({val_i[indice]},{val_j[indice]})'
            if plot_legend:
                ax2.plot(series_spectrales[indice], label=label)
                ax2.legend()
            else:
                ax2.plot(series_spectrales[indice])
        if names is not None:
            ax2.set_xticks(range(len(names)))
            ax2.set_xticklabels([key for key in names], rotation=45)
        ax2.set_title(title_spectra)
        ax2.set_xlabel(xlabel)
        ax2.set_ylabel(ylabel)
        fig.canvas.draw()

    # Connect the click event
    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    # Add a widget to indicate that a click is expected
    ready_label = widgets.Label(value="Click on the image to retrieve spectra.")
    display(ready_label)
    # Display the figure and allow the user to click
    plt.show()
    # Return the data after the figure is closed
    return series_spectrales, val_i, val_j, end_collect
[docs]
def dict_keys_to_list(d):
    """
    Return the dictionary's keys as a list of strings.

    Parameters
    ----------
    d : dict
        The input dictionary.

    Returns
    -------
    list of str
        The keys of ``d``, each converted with ``str``, in insertion order.

    Example
    -------
    >>> dict_keys_to_list({'a': 1, 'b': 2, 'c': 3})
    ['a', 'b', 'c']
    """
    return [str(key) for key in d]
[docs]
def match_dimensions(tab1, tab2):
    """
    Make ``tab1`` match ``tab2`` along the first (band) axis.

    Parameters
    ----------
    tab1 : numpy.ndarray
        Array to adjust.
    tab2 : numpy.ndarray
        Reference array providing the target shape.

    Returns
    -------
    numpy.ndarray
        ``tab1`` unchanged when the shapes already agree; otherwise, when
        ``tab1`` holds a single band with matching trailing dimensions, that
        band replicated to the band count of ``tab2``.

    Raises
    ------
    ValueError
        When the shapes are incompatible.
    """
    if tab1.shape == tab2.shape:
        # Shapes already identical: nothing to do.
        return tab1
    single_band = tab1.shape[0] == 1
    same_trailing = tab1.shape[1:] == tab2.shape[1:]
    if single_band and same_trailing:
        # Duplicate the single band of tab1 to match tab2's band count.
        return np.repeat(tab1, tab2.shape[0], axis=0)
    raise ValueError("Les dimensions des tableaux ne sont pas compatibles.")
[docs]
def resize_array(array, new_shape):
    """
    Resize a 2-D array to ``new_shape`` with separable linear interpolation.

    Parameters
    ----------
    array : numpy.ndarray
        2-D input array.
    new_shape : tuple of int
        Target (rows, cols).

    Returns
    -------
    numpy.ndarray
        Array of shape ``new_shape`` obtained by interpolating first along
        columns, then along rows.
    """
    src_rows, src_cols = array.shape
    dst_rows, dst_cols = new_shape
    # Fractional sample positions in the source grid.
    target_rows = np.linspace(0, src_rows - 1, dst_rows)
    target_cols = np.linspace(0, src_cols - 1, dst_cols)
    # First pass: resample every row to the new column count.
    horiz = np.array([np.interp(target_cols, np.arange(src_cols), line)
                      for line in array])
    # Second pass: resample every (already widened) column to the new row count.
    resized = np.array([np.interp(target_rows, np.arange(src_rows), column)
                        for column in horiz.T]).T
    return resized
[docs]
def fusion_2classes(m1, m2):
    """
    Fuse two 2-class mass-function tables with Dempster's combination rule.

    Parameters
    ----------
    m1 : numpy.ndarray
        N x 3 array for source 1; columns are m(class 1), m(class 2),
        and m(class 1 ∪ class 2).
    m2 : numpy.ndarray
        N x 3 array for source 2, same column layout.

    Returns
    -------
    numpy.ndarray
        N x 3 array of fused masses: m(class 1), m(class 2),
        m(class 1 ∪ class 2).
    numpy.ndarray
        1-D array of length N with the conflict value of each row.
        NOTE(review): rows with total conflict (value 1) divide by zero.

    Examples
    --------
    >>> fused, conflict = fusion_2classes(m1, m2)
    """
    # Per-hypothesis masses: a* for source 1, b* for source 2.
    a1, a2, a12 = m1[:, 0], m1[:, 1], m1[:, 2]
    b1, b2, b12 = m2[:, 0], m2[:, 1], m2[:, 2]

    # Conflict: mass jointly assigned to the two disjoint hypotheses.
    conflict = a1 * b2 + a2 * b1
    norm = 1 - conflict

    fused = np.zeros(m1.shape)
    fused[:, 0] = (a1 * b1 + a1 * b12 + b1 * a12) / norm
    fused[:, 1] = (a2 * b2 + a2 * b12 + b2 * a12) / norm
    fused[:, 2] = (a12 * b12) / norm
    return fused, conflict
[docs]
def fusion_multisource(*args):
    """
    Fuse mass-function tables from an arbitrary number of sources.

    Sources are combined pairwise, left to right, with ``fusion_2classes``.

    Parameters
    ----------
    *args : numpy.ndarray
        One or more N x 3 arrays, one per source; columns are m(class 1),
        m(class 2), and m(class 1 ∪ class 2).

    Returns
    -------
    numpy.ndarray
        N x 3 array of fused mass functions.
    numpy.ndarray
        1-D array of length N with the conflict of the last pairwise
        combination.

    Raises
    ------
    ValueError
        If no source is given.

    Examples
    --------
    >>> fused, conflict = fusion_multisource(m1, m2, m3)
    """
    # Original code raised a bare IndexError on empty/one-source input and
    # computed a throwaway np.zeros array that was immediately overwritten.
    if not args:
        raise ValueError("fusion_multisource requires at least one source.")
    if len(args) == 1:
        # A single source "fuses" to itself with zero conflict everywhere.
        return args[0], np.zeros(args[0].shape[0])
    fused, conflict = fusion_2classes(args[0], args[1])
    for source in args[2:]:
        fused, conflict = fusion_2classes(fused, source)
    return fused, conflict