diff --git a/operators/quantem-direct-ptycho/operator.json b/operators/quantem-direct-ptycho/operator.json index ab1d5c68..5abc3021 100644 --- a/operators/quantem-direct-ptycho/operator.json +++ b/operators/quantem-direct-ptycho/operator.json @@ -20,22 +20,6 @@ } ], "parameters": [ - { - "name": "calculation_frequency", - "label": "Calculation Frequency", - "type": "int", - "default": "100", - "description": "Number of frames to accumulate before recalculating the center and emitting a BF image.", - "required": true - }, - { - "name": "max_concurrent_scans", - "label": "Max Concurrent Scans", - "type": "int", - "default": "1", - "description": "Maximum number of scans to keep in memory simultaneously. Oldest scans are evicted when this limit is exceeded.", - "required": false - }, { "name": "accelerating_voltage", "label": "Accelerating voltage", @@ -61,11 +45,19 @@ "required": false }, { - "name": "initial_defocus", - "label": "STEM defocus", - "type": "float", - "default": "0.0", - "description": "The STEM defocus in nm.", + "name": "defocus_search_range_min", + "label": "Defocus search range minimum", + "type": "int", + "default": "0", + "description": "The defocus search range minimum in nanometers.", + "required": false + }, + { + "name": "defocus_search_range_max", + "label": "Defocus search range maximum", + "type": "int", + "default": "30", + "description": "The defocus search range maximum in nanometers.", "required": false }, { @@ -76,6 +68,14 @@ "description": "The rotation of the diffraction pattern on the detector in degrees.", "required": false }, + { + "name": "maximum_C12_magnitude", + "label": "Maximum C12 magnitude", + "type": "int", + "default": "10", + "description": "The maximum C12 magnitude in nanometers.", + "required": false + }, { "name": "crop_probes", "label": "Crop probes on each side", @@ -93,19 +93,19 @@ "required": false }, { - "name": "defocus_search_range_nm", - "label": "Defocus search range", - "type": "float", - "default": "50.0", - "description": "The defocus search range in nanometers. Unused if defocus is input.", + "name": "n_trials", + "label": "Number of trials for hyperparameter optimization", + "type": "int", + "default": "25", + "description": "The number of trials for hyperparameter optimization.", "required": false }, - { - "name": "maximum_C12_magnitude_nm", - "label": "Maximum C12 magnitude", + { + "name": "max_batch_size", + "label": "Maximum batch size", "type": "int", - "default": "2", - "description": "The maximum C12 magnitude in nanometers.", + "default": "10", + "description": "The maximum batch size for processing frames. Reduce if you run out of GPU memory.", "required": false }, { @@ -116,14 +116,6 @@ "options": ["parallax", "ssb", "icom"], "description": "The deconvolution kernel.", "required": false - }, - { - "name": "use_optimization", - "label": "Use optimization routine", - "type": "bool", - "default": true, - "description": "Use the optimization routine with initial parameters.", - "required": false } ], "parallel_config": { diff --git a/operators/quantem-direct-ptycho/run.py b/operators/quantem-direct-ptycho/run.py index 25d243b7..060a33c8 100644 --- a/operators/quantem-direct-ptycho/run.py +++ b/operators/quantem-direct-ptycho/run.py @@ -80,7 +80,7 @@ def quantem_direct_ptycho( scan_number = batch.header.scan_number # --- 2. Get or Create FrameAccumulator --- - max_concurrent_scans = int(parameters.get("max_concurrent_scans", 1)) + max_concurrent_scans = 1 # TODO: remove saving of old scans if scan_number not in accumulators: # Check if we need to evict old accumulators before creating new one @@ -114,56 +114,35 @@ def quantem_direct_ptycho( return None # --- 5. Perform Calculation --- - logger.info( - f"Scan {scan_number}: Triggering calculation after {accumulator.num_batches_added} messages." - ) - logger.info(f"Accumulator finished: {accumulator.finished}") - logger.info(f"Scan {scan_number}: Calculating ptycho images.") # Calculation parameters - probe_semiangle = parameters.get("probe_semiangle", 25.0) energy = parameters.get("accelerating_voltage", 300e3) - probe_step_size = parameters.get( - "probe_step_size", 0.1 - ) # test data set: 0.14383155 nm - crop_probes = parameters.get("crop_probes", 0) + probe_semiangle = parameters.get("probe_semiangle", 25.0) + # test data set: 0.14383155 nm probe step size + probe_step_size_nm = parameters.get("probe_step_size", 0.1) + probe_step_size_A = probe_step_size_nm * 10 upsampling_factor = parameters.get("upsampling_factor", 2) + n_trials = parameters.get("n_trials", 25) + max_batch_size = parameters.get("max_batch_size", 10) - # Parameters for optimize_hyperparameters function - initial_defocus_nm = parameters.get( - "initial_defocus", None - ) # in nanometers, can be None - if initial_defocus_nm is not None: - initial_defocus_nm = initial_defocus_nm - initial_defocus_A = initial_defocus_nm * 10 # convert to Angstroms - else: - initial_defocus_A = None - - diffraction_rotation_angle = parameters.get( - "diffraction_rotation_angle", None - ) # in degrees, can be None - if diffraction_rotation_angle is not None: - diffraction_rotation_angle = diffraction_rotation_angle - rotation_angle = diffraction_rotation_angle * np.pi / 180 # convert to radians - else: - rotation_angle = None + defocus_search_min_nm = parameters.get("defocus_search_range_min", 50) + defocus_search_max_nm = parameters.get("defocus_search_range_max", 50) + defocus_search_min_A = defocus_search_min_nm * 10 # convert to Angstroms + defocus_search_max_A = defocus_search_max_nm * 10 # convert to Angstroms + # Need to convert signs and order because of different conventions in FEI and quantem + defocus_search_range_A = (-defocus_search_max_A, -defocus_search_min_A) - defocus_search_range_nm = parameters.get( - "defocus_search_range", 50 - ) # in nanometers - defocus_search_range_A = defocus_search_range_nm * 10 # convert to Angstroms + # in degrees + diffraction_rotation_angle_deg = parameters.get("diffraction_rotation_angle", 0) + rotation_angle = diffraction_rotation_angle_deg * np.pi / 180 # convert to radians - maximum_C12_magnitude_nm = parameters.get( - "maximum_C12_magnitude", 10 - ) # in nanometers + maximum_C12_magnitude_nm = parameters.get("maximum_C12_magnitude", 10) maximum_C12_magnitude_A = maximum_C12_magnitude_nm * 10 # convert to Angstroms deconvolution_kernel = parameters.get("deconvolution_kernel", "parallax") - # Determine whether to use optimization or manual settings - use_optimization = bool(parameters.get("use_optimization", True)) - + crop_probes = parameters.get("crop_probes", 0) if crop_probes == 0: logger.info(f"Scan {scan_number}: No cropping of probes applied.") dense_data = accumulator[:, :-1, :, :].to_dense() ## remove the flyback column @@ -173,6 +152,7 @@ def quantem_direct_ptycho( crop_probes:-crop_probes, crop_probes : -crop_probes - 1, :, : ].to_dense() ## crop the edges if needed and remove the flyback column + # Convert SparseArray to Dataset4dstem dset = em.datastructures.Dataset4dstem.from_array(array=dense_data) logger.debug(f"dense shape = {dense_data.shape}") @@ -184,81 +164,64 @@ def quantem_direct_ptycho( dset.sampling[3] = probe_semiangle / probe_R dset.units[2:] = ["mrad", "mrad"] - dset.sampling[0] = ( - probe_step_size * 10 - ) ## convert to be Anggstrom for quantem. distiller will give nanometers. - dset.sampling[1] = probe_step_size * 10 + dset.sampling[0:2] = probe_step_size_A dset.units[0:2] = ["A", "A"] logger.info(f"Scan {scan_number}: Start direct ptycho") try: - # Initialize DirectPtychography with initial guesses - aberration_coefs = {} - if initial_defocus_A is not None: - aberration_coefs["C10"] = -initial_defocus_A # Note the negative sign + # Initialize DirectPtychography direct_ptycho = DirectPtychography.from_dataset4d( dset, energy=energy, semiangle_cutoff=probe_semiangle, device=QUANTEM_DEVICE, - aberration_coefs=aberration_coefs if aberration_coefs else None, - max_batch_size=10, + aberration_coefs={}, + max_batch_size=max_batch_size, rotation_angle=rotation_angle, # need radians ) - if use_optimization: - logger.info(f"Scan {scan_number}: Optimizing hyperparameters") - - # Build optimization aberration coefficients - opt_aberration_coefs = {} - if initial_defocus_A is None: - opt_aberration_coefs["C10"] = OptimizationParameter( - defocus_search_range_A, defocus_search_range_A - ) - else: - opt_aberration_coefs["C10"] = -initial_defocus_A - - opt_aberration_coefs["C12"] = OptimizationParameter( - 0, maximum_C12_magnitude_A - ) - opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2) - - # Build rotation angle optimization - if rotation_angle is None: - opt_rotation_angle = OptimizationParameter(0, np.pi) - else: - opt_rotation_angle = rotation_angle - - direct_ptycho.optimize_hyperparameters( - aberration_coefs=opt_aberration_coefs, - rotation_angle=opt_rotation_angle, - deconvolution_kernel=deconvolution_kernel, - n_trials=50, - max_batch_size=10, - ) - else: - logger.info(f"Scan {scan_number}: Using manual hyperparameter settings") + # Build optimization aberration coefficients + logger.info(f"Scan {scan_number}: Optimizing hyperparameters") + opt_aberration_coefs = {} + opt_aberration_coefs["C10"] = OptimizationParameter( + defocus_search_range_A[0], defocus_search_range_A[1] + ) + opt_aberration_coefs["C12"] = OptimizationParameter(0, maximum_C12_magnitude_A) + opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2) + + # Optimize hyperparameters + direct_ptycho.optimize_hyperparameters( + aberration_coefs=opt_aberration_coefs, + deconvolution_kernel="parallax", + n_trials=n_trials, + max_batch_size=max_batch_size, + ) - initial_parallax = direct_ptycho.reconstruct( + # Do reconstruction + logger.info(f"Scan {scan_number}: Starting reconstruction") + direct_ptycho.reconstruct( deconvolution_kernel=deconvolution_kernel, upsampling_factor=upsampling_factor, - max_batch_size=10, + max_batch_size=max_batch_size, ) # Process and return result logger.info(f"Scan {scan_number}: Reconstruction done") - output_bytes = initial_parallax.obj.tobytes() + output_bytes = direct_ptycho.obj.tobytes() output_meta = { "scan_number": scan_number, - "shape": initial_parallax.obj.shape, - "dtype": str(initial_parallax.obj.dtype), + "shape": direct_ptycho.obj.shape, + "dtype": str(direct_ptycho.obj.dtype), "source_operator": "quantem-direct-ptycho", - "direct_ptycho_params": {'C12': direct_ptycho.hyperparameter_state.optimized_aberrations['C12'], - 'phi12': direct_ptycho.hyperparameter_state.optimized_aberrations['phi12'], - 'C10': -direct_ptycho.aberration_coefs['C10'], - 'rotation_angle': direct_ptycho.rotation_angle, - }, + "direct_ptycho_params": { + "C12": direct_ptycho.hyperparameter_state.optimized_aberrations["C12"], + "phi12": direct_ptycho.hyperparameter_state.optimized_aberrations[ + "phi12" + ], + "C10": -direct_ptycho.aberration_coefs["C10"], + "rotation_angle": direct_ptycho.rotation_angle, + }, } header = MessageHeader(subject=MessageSubject.BYTES, meta=output_meta) return BytesMessage(header=header, data=output_bytes)