Skip to content
70 changes: 31 additions & 39 deletions operators/quantem-direct-ptycho/operator.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,6 @@
}
],
"parameters": [
{
"name": "calculation_frequency",
"label": "Calculation Frequency",
"type": "int",
"default": "100",
"description": "Number of frames to accumulate before recalculating the center and emitting a BF image.",
"required": true
},
{
"name": "max_concurrent_scans",
"label": "Max Concurrent Scans",
"type": "int",
"default": "1",
"description": "Maximum number of scans to keep in memory simultaneously. Oldest scans are evicted when this limit is exceeded.",
"required": false
},
{
"name": "accelerating_voltage",
"label": "Accelerating voltage",
Expand All @@ -61,11 +45,19 @@
"required": false
},
{
"name": "initial_defocus",
"label": "STEM defocus",
"type": "float",
"default": "0.0",
"description": "The STEM defocus in nm.",
"name": "defocus_search_range_min",
"label": "Defocus search range minimum",
"type": "int",
"default": "0",
"description": "The defocus search range minimum in nanometers.",
"required": false
},
{
"name": "defocus_search_range_max",
"label": "Defocus search range maximum",
"type": "int",
"default": "30",
"description": "The defocus search range maximum in nanometers.",
"required": false
},
{
Expand All @@ -76,6 +68,14 @@
"description": "The rotation of the diffraction pattern on the detector in degrees.",
"required": false
},
{
"name": "maximum_C12_magnitude",
"label": "Maximum C12 magnitude",
"type": "int",
"default": "10",
"description": "The maximum C12 magnitude in nanometers.",
"required": false
},
{
"name": "crop_probes",
"label": "Crop probes on each side",
Expand All @@ -93,19 +93,19 @@
"required": false
},
{
"name": "defocus_search_range_nm",
"label": "Defocus search range",
"type": "float",
"default": "50.0",
"description": "The defocus search range in nanometers. Unused if defocus is input.",
"name": "n_trials",
"label": "Number of trials for hyperparameter optimization",
"type": "int",
"default": "25",
"description": "The number of trials for hyperparameter optimization.",
"required": false
},
{
"name": "maximum_C12_magnitude_nm",
"label": "Maximum C12 magnitude",
{
"name": "max_batch_size",
"label": "Maximum batch size",
"type": "int",
"default": "2",
"description": "The maximum C12 magnitude in nanometers.",
"default": "10",
"description": "The maximum batch size for processing frames. Reduce if you run out of GPU memory.",
"required": false
},
{
Expand All @@ -116,14 +116,6 @@
"options": ["parallax", "ssb", "icom"],
"description": "The deconvolution kernel.",
"required": false
},
{
"name": "use_optimization",
"label": "Use optimization routine",
"type": "bool",
"default": true,
"description": "Use the optimization routine with initial parameters.",
"required": false
}
],
"parallel_config": {
Expand Down
145 changes: 54 additions & 91 deletions operators/quantem-direct-ptycho/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def quantem_direct_ptycho(
scan_number = batch.header.scan_number

# --- 2. Get or Create FrameAccumulator ---
max_concurrent_scans = int(parameters.get("max_concurrent_scans", 1))
max_concurrent_scans = 1 # TODO: remove saving of old scans

if scan_number not in accumulators:
# Check if we need to evict old accumulators before creating new one
Expand Down Expand Up @@ -114,56 +114,35 @@ def quantem_direct_ptycho(
return None

# --- 5. Perform Calculation ---
logger.info(
f"Scan {scan_number}: Triggering calculation after {accumulator.num_batches_added} messages."
)
logger.info(f"Accumulator finished: {accumulator.finished}")

logger.info(f"Scan {scan_number}: Calculating ptycho images.")

# Calculation parameters
probe_semiangle = parameters.get("probe_semiangle", 25.0)
energy = parameters.get("accelerating_voltage", 300e3)
probe_step_size = parameters.get(
"probe_step_size", 0.1
) # test data set: 0.14383155 nm
crop_probes = parameters.get("crop_probes", 0)
probe_semiangle = parameters.get("probe_semiangle", 25.0)
# test data set: 0.14383155 nm probe step size
probe_step_size_nm = parameters.get("probe_step_size", 0.1)
probe_step_size_A = probe_step_size_nm * 10
upsampling_factor = parameters.get("upsampling_factor", 2)
n_trials = parameters.get("n_trials", 25)
max_batch_size = parameters.get("max_batch_size", 10)

# Parameters for optimize_hyperparameters function
initial_defocus_nm = parameters.get(
"initial_defocus", None
) # in nanometers, can be None
if initial_defocus_nm is not None:
initial_defocus_nm = initial_defocus_nm
initial_defocus_A = initial_defocus_nm * 10 # convert to Angstroms
else:
initial_defocus_A = None

diffraction_rotation_angle = parameters.get(
"diffraction_rotation_angle", None
) # in degrees, can be None
if diffraction_rotation_angle is not None:
diffraction_rotation_angle = diffraction_rotation_angle
rotation_angle = diffraction_rotation_angle * np.pi / 180 # convert to radians
else:
rotation_angle = None
defocus_search_min_nm = parameters.get("defocus_search_range_min", 50)
defocus_search_max_nm = parameters.get("defocus_search_range_max", 50)
defocus_search_min_A = defocus_search_min_nm * 10 # convert to Angstroms
defocus_search_max_A = defocus_search_max_nm * 10 # convert to Angstroms
# Need to convert signs and order because of different conventions in FEI and quantem
defocus_search_range_A = (-defocus_search_max_A, -defocus_search_min_A)

defocus_search_range_nm = parameters.get(
"defocus_search_range", 50
) # in nanometers
defocus_search_range_A = defocus_search_range_nm * 10 # convert to Angstroms
# in degrees
diffraction_rotation_angle_deg = parameters.get("diffraction_rotation_angle", 0)
rotation_angle = diffraction_rotation_angle_deg * np.pi / 180 # convert to radians

maximum_C12_magnitude_nm = parameters.get(
"maximum_C12_magnitude", 10
) # in nanometers
maximum_C12_magnitude_nm = parameters.get("maximum_C12_magnitude", 10)
maximum_C12_magnitude_A = maximum_C12_magnitude_nm * 10 # convert to Angstroms

deconvolution_kernel = parameters.get("deconvolution_kernel", "parallax")

# Determine whether to use optimization or manual settings
use_optimization = bool(parameters.get("use_optimization", True))

crop_probes = parameters.get("crop_probes", 0)
if crop_probes == 0:
logger.info(f"Scan {scan_number}: No cropping of probes applied.")
dense_data = accumulator[:, :-1, :, :].to_dense() ## remove the flyback column
Expand All @@ -173,6 +152,7 @@ def quantem_direct_ptycho(
crop_probes:-crop_probes, crop_probes : -crop_probes - 1, :, :
].to_dense() ## crop the edges if needed and remove the flyback column

# Convert SparseArray to Dataset4dstem
dset = em.datastructures.Dataset4dstem.from_array(array=dense_data)
logger.debug(f"dense shape = {dense_data.shape}")

Expand All @@ -184,81 +164,64 @@ def quantem_direct_ptycho(
dset.sampling[3] = probe_semiangle / probe_R
dset.units[2:] = ["mrad", "mrad"]

dset.sampling[0] = (
probe_step_size * 10
) ## convert to be Anggstrom for quantem. distiller will give nanometers.
dset.sampling[1] = probe_step_size * 10
dset.sampling[0:2] = probe_step_size_A
dset.units[0:2] = ["A", "A"]
Comment thread
ercius marked this conversation as resolved.

logger.info(f"Scan {scan_number}: Start direct ptycho")
try:
# Initialize DirectPtychography with initial guesses
aberration_coefs = {}
if initial_defocus_A is not None:
aberration_coefs["C10"] = -initial_defocus_A # Note the negative sign
# Initialize DirectPtychography

direct_ptycho = DirectPtychography.from_dataset4d(
dset,
energy=energy,
semiangle_cutoff=probe_semiangle,
device=QUANTEM_DEVICE,
aberration_coefs=aberration_coefs if aberration_coefs else None,
max_batch_size=10,
aberration_coefs={},
max_batch_size=max_batch_size,
rotation_angle=rotation_angle, # need radians
)

if use_optimization:
logger.info(f"Scan {scan_number}: Optimizing hyperparameters")

# Build optimization aberration coefficients
opt_aberration_coefs = {}
if initial_defocus_A is None:
opt_aberration_coefs["C10"] = OptimizationParameter(
defocus_search_range_A, defocus_search_range_A
)
else:
opt_aberration_coefs["C10"] = -initial_defocus_A

opt_aberration_coefs["C12"] = OptimizationParameter(
0, maximum_C12_magnitude_A
)
opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2)

# Build rotation angle optimization
if rotation_angle is None:
opt_rotation_angle = OptimizationParameter(0, np.pi)
else:
opt_rotation_angle = rotation_angle

direct_ptycho.optimize_hyperparameters(
aberration_coefs=opt_aberration_coefs,
rotation_angle=opt_rotation_angle,
deconvolution_kernel=deconvolution_kernel,
n_trials=50,
max_batch_size=10,
)
else:
logger.info(f"Scan {scan_number}: Using manual hyperparameter settings")
# Build optimization aberration coefficients
logger.info(f"Scan {scan_number}: Optimizing hyperparameters")
opt_aberration_coefs = {}
opt_aberration_coefs["C10"] = OptimizationParameter(
defocus_search_range_A[0], defocus_search_range_A[1]
)
opt_aberration_coefs["C12"] = OptimizationParameter(0, maximum_C12_magnitude_A)
opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2)

# Optimize hyperparameters
direct_ptycho.optimize_hyperparameters(
aberration_coefs=opt_aberration_coefs,
deconvolution_kernel="parallax",
n_trials=n_trials,
max_batch_size=max_batch_size,
)

initial_parallax = direct_ptycho.reconstruct(
# Do reconstruction
logger.info(f"Scan {scan_number}: Starting reconstruction")
direct_ptycho.reconstruct(
deconvolution_kernel=deconvolution_kernel,
upsampling_factor=upsampling_factor,
max_batch_size=10,
max_batch_size=max_batch_size,
)

# Process and return result
logger.info(f"Scan {scan_number}: Reconstruction done")
output_bytes = initial_parallax.obj.tobytes()
output_bytes = direct_ptycho.obj.tobytes()
output_meta = {
"scan_number": scan_number,
"shape": initial_parallax.obj.shape,
"dtype": str(initial_parallax.obj.dtype),
"shape": direct_ptycho.obj.shape,
"dtype": str(direct_ptycho.obj.dtype),
"source_operator": "quantem-direct-ptycho",
"direct_ptycho_params": {'C12': direct_ptycho.hyperparameter_state.optimized_aberrations['C12'],
'phi12': direct_ptycho.hyperparameter_state.optimized_aberrations['phi12'],
'C10': -direct_ptycho.aberration_coefs['C10'],
'rotation_angle': direct_ptycho.rotation_angle,
},
"direct_ptycho_params": {
"C12": direct_ptycho.hyperparameter_state.optimized_aberrations["C12"],
"phi12": direct_ptycho.hyperparameter_state.optimized_aberrations[
"phi12"
],
"C10": -direct_ptycho.aberration_coefs["C10"],
"rotation_angle": direct_ptycho.rotation_angle,
},
}
header = MessageHeader(subject=MessageSubject.BYTES, meta=output_meta)
return BytesMessage(header=header, data=output_bytes)
Expand Down
Loading