Skip to content

Commit 2fab2c8

Browse files
committed
Delete the temporary NETCDF4 file if converting to CDF5
Test to make sure it's been deleted.
1 parent a60477f commit 2fab2c8

2 files changed

Lines changed: 49 additions & 2 deletions

File tree

conda_package/mpas_tools/io.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import subprocess
33
import sys
44
from datetime import datetime
5+
from pathlib import Path
56

67
import netCDF4
78
import numpy
@@ -121,8 +122,10 @@ def write_netcdf(
121122
convert = format == 'NETCDF3_64BIT_DATA'
122123

123124
if convert:
124-
basename, extension = os.path.splitext(fileName)
125-
out_filename = f'{basename}.netcdf4{extension}'
125+
out_path = Path(fileName)
126+
out_filename = (
127+
out_path.parent / f'_tmp_{out_path.stem}.netcdf4{out_path.suffix}'
128+
)
126129
format = 'NETCDF4'
127130
if engine == 'scipy':
128131
# that's not going to work
@@ -151,6 +154,8 @@ def write_netcdf(
151154
)
152155
else:
153156
check_call(args, logger=logger)
157+
# delete the temporary NETCDF4 file
158+
os.remove(out_filename)
154159

155160

156161
def update_history(ds):

conda_package/tests/test_io.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def test_write_netcdf_cdf5_format(tmp_path):
4343
)
4444
# Should be cdf5 for NETCDF3_64BIT_DATA
4545
assert result.stdout.strip() == 'cdf5'
46+
# Check that the temporary file was deleted
47+
tmp_file = (
48+
out_file.parent / f'_tmp_{out_file.stem}.netcdf4{out_file.suffix}'
49+
)
50+
assert not os.path.exists(tmp_file)
4651

4752

4853
def test_write_netcdf_int64_conversion_and_attr(tmp_path):
@@ -58,3 +63,40 @@ def test_write_netcdf_int64_conversion_and_attr(tmp_path):
5863
# Attribute should be preserved
5964
assert ds2['foo'].attrs['myattr'] == 'testattr'
6065
ds2.close()
66+
67+
68+
def test_write_netcdf_fill_value(tmp_path):
69+
# Test that NaN values are written with correct fill value
70+
arr = np.array([1.0, np.nan, 3.0], dtype=np.float32)
71+
ds = xr.Dataset({'bar': (('x',), arr)})
72+
out_file = tmp_path / 'test_fill.nc'
73+
write_netcdf(ds, str(out_file))
74+
ds2 = xr.open_dataset(out_file)
75+
# The second value should be the default fill value for float32
76+
fill_value = ds2['bar'].encoding.get('_FillValue', None)
77+
assert fill_value is not None
78+
assert np.isnan(ds2['bar'].values[1])
79+
ds2.close()
80+
81+
82+
def test_write_netcdf_string_dim_name(tmp_path):
83+
# Test that custom char_dim_name is used in encoding
84+
arr = np.array([b'abc', b'def'])
85+
ds = xr.Dataset({'baz': (('x',), arr)})
86+
out_file = tmp_path / 'test_strdim.nc'
87+
write_netcdf(ds, str(out_file), char_dim_name='CustomStrLen')
88+
ds2 = xr.open_dataset(out_file)
89+
# Should have the variable and correct shape
90+
assert 'baz' in ds2.variables
91+
ds2.close()
92+
arr = np.array([1, 2, 3], dtype=np.int64)
93+
ds = xr.Dataset({'foo': (('x',), arr)})
94+
ds['foo'].attrs['myattr'] = 'testattr'
95+
out_file = tmp_path / 'test_int64.nc'
96+
write_netcdf(ds, str(out_file))
97+
ds2 = xr.open_dataset(out_file)
98+
# Should be int32, not int64
99+
assert ds2['foo'].dtype == np.int32
100+
# Attribute should be preserved
101+
assert ds2['foo'].attrs['myattr'] == 'testattr'
102+
ds2.close()

0 commit comments

Comments
 (0)