Implement molecular connectivity checks#2002
Conversation
new helpers in core/utils/structure.py:
- reference_graph_from_atoms builds an nx.Graph (atomic_num node attr,
no bond order) from a reference conformer .xyz/.sdf/.mol.
- check_molecule_matches_reference does per-fragment graph isomorphism
vs the reference graph via nx.is_isomorphic with categorical_node_match
on atomic_num.
process_generated.py:
- Build the reference graph once per process_genarris_outputs_single
(one call per (mol_id, conf_id)) by looking for <conf_id>.{xyz,sdf,mol}
inside the conformer directory.
- Thread the reference graph into structure_to_row; emit a new column
validity.crystal_generated.molecule_matches_reference alongside
the existing validity.crystal_generated.correct_z.
relax.py / filter.py:
- Rename validity.connectivity_unchanged to
validity.crystal_relaxed.connectivity_unchanged for namespace
consistency with validity.crystal_relaxed.z_unchanged.
Drops the structures whose generation-time validity flags (validity.crystal_generated.correct_z, validity.crystal_generated.molecule_matches_reference) are False before deduplication and writing the parquets into the directory raw structures. Split into structures_df_filtered / problematic_structures_df, run deduplicate_structures only on the valid subset, mark problematic rows with group_index=-1, and optionally reintegrate. Default is False (preserve), matching get_post_relax_config.
…graphs in the relaxed polymorph - add validity.crystal_relaxed.molecule_matches_reference
lbluque
left a comment
There was a problem hiding this comment.
thx @jagritisahoo - these checks look much cleaner and more robust. I just left some small comments.
| # XYZ-loaded molecules have no unit cell (cell rank < 3), which makes | ||
| # AseAtomsAdaptor.get_structure raise LinAlgError on the singular | ||
| # lattice. Pad with a generously large cubic box so the molecule sits | ||
| # well inside and pymatgen can build a periodic Structure for JmolNN. | ||
| if np.linalg.matrix_rank(np.array(reference_atoms.cell)) < 3: | ||
| reference_atoms = reference_atoms.copy() | ||
| reference_atoms.cell = np.eye(3) * 30.0 | ||
| reference_atoms.center() | ||
| reference_atoms.pbc = True |
There was a problem hiding this comment.
Not necessary to change this, but pmg has a Molecule object as well, and likely JMolNN or equivalent that works on those. We might be able to just do that directly to avoid this.
There was a problem hiding this comment.
I think so
from pymatgen.analysis.graphs import MoleculeGraph
from pymatgen.analysis.local_env import JMolNN
molecule_graph = MoleculeGraph.with_local_env_strategy(my_molecule, JMolNN())
neighbors_list = molecule_graph.get_neighbors(site_index)
| # Build the nx.Graph (atomic_num node attr; undirected edges | ||
| graph = nx.Graph() | ||
| for i in range(n): | ||
| graph.add_node(i, atomic_num=structure[i].specie.number) | ||
| for i in range(n): | ||
| for entry in nn_info[i]: | ||
| j = entry["site_index"] | ||
| if i < j: | ||
| graph.add_edge(i, j) | ||
| return graph |
There was a problem hiding this comment.
Does this method work here to clean things up? https://networkx.org/documentation/stable/reference/generated/networkx.convert_matrix.from_numpy_array.html
There was a problem hiding this comment.
This is implemented now. Could you check the latest implementation @lbluque ?
| # Any difference indicates bond formation/breaking during relaxation | ||
| return np.array_equal(initial_nn_matrix, final_nn_matrix) | ||
|
|
||
| def check_connectivity_changes( |
There was a problem hiding this comment.
Do we still need this if we are using the method above check_molecule_matches_reference? That one looks a lot more robust, since ordering is not an issue. I suggest just removing this one altogether since it could lead to false negative being dropped due to site permutations.
There was a problem hiding this comment.
@lbluque yeah I think this is taken care by check_molecules_matches_reference and we probably do not need this. However, we can still keep the check_correct_z as a pre-filter post relaxation.
| if check_molecule_count: | ||
| initial_molecule_count = int( | ||
| csgraph.connected_components(initial_nn_matrix)[0] | ||
| ) | ||
| final_molecule_count = int(csgraph.connected_components(final_nn_matrix)[0]) | ||
| result["initial_molecule_count"] = initial_molecule_count | ||
| result["final_molecule_count"] = final_molecule_count | ||
| molecule_count_preserved = initial_molecule_count == final_molecule_count | ||
| result["molecule_count_preserved"] = molecule_count_preserved | ||
| if not molecule_count_preserved: | ||
| result["no_changes"] = False |
There was a problem hiding this comment.
This part isnt affected by site re-orderings, but is it different than checking z_changes with the function above? if not, i suggest just keeping one of them.
There was a problem hiding this comment.
@lbluque It is slightly different in the sense that the function check_correct_z compares the number of molecules in initial structure with the request Z which was used as an input to Genarris. This is checking the number of Z before relaxation and compares that to the Z after relaxation, essentially capturing bond breaking/fusing of molecules. Some of the code, such as building the nn_matrix can be abstracted out for sure.
|
So this doesn't include the code changes I made in https://github.com/fairinternal/generative_chemistry/pull/42/changes#diff-2e55eeb5b4c279f4e75e69b5ada39350bbf5541c0ff692de7c05bb2b72aad3fe which add stereochemistry checking. I think we really need these because we had a number of cases where the mispredictions were because the "lowest" energy was a different diastereomer and so it threw out all of the correct stereoisomer structures. Note that I don't believe I had a flag in my code for noting racemates (i.e. crystals when both enantiomers are present). We could easily just not worry about enantiomers and only diastereomers because if we had the full enantiomer of a crystal than the energy is the same, and I don't think genarris+UMA can produce non-racemate enatiomeric mixtures. |
…ges with reference-anchored correct_z + molecule_matches_reference - structure.py: refactor reference_graph_from_atoms and check_molecule_matches_reference to use nx.from_numpy_array; delete check_connectivity_changes + check_no_changes_in_covalent_matrix + check_no_changes_in_Z - relax.py: drop atoms_list_original snapshot; write validity.crystal_relaxed.correct_z (replaces .z_unchanged and .connectivity_unchanged) - filter.py: rewrite problematic mask to use correct_z + molecule_matches_reference; keep root_unrelaxed as opt-in toggle to recompute post-relax flags on the relaxed CIF when relax did not write them; add generated_structures_dir param for the reference graph - main.py: pass root_unrelaxed and generated_structures_dir to filter so the recompute runs by default
# Conflicts: # src/fairchem/applications/fastcsp/core/workflow/filter.py
- add jmolnn_adjacency(structure_or_atoms) helper refactor check_correct_z, reference_graph_from_atoms, and check_molecule_matches_reference to call it - add check_connectivity_unchanged(initial, final) for the strict, site-ordered init->final JmolNN bond-matrix comparison
| # Build undirected nx.Graph from the adjacency matrix and attach | ||
| # atomic_num as a per-node attribute (used by the categorical node | ||
| # match in check_molecule_matches_reference). | ||
| graph = nx.from_numpy_array(nn_matrix) |
There was a problem hiding this comment.
lg! I think you can pass the nodes directly to this function as well, instead of looping like below.
Implemented molecular connectivity checks at two stages in the FastCSP workflow
Post genarris stage
validity.crystal_generated.correct_zto the parquetvalidity.crystal_generated.molecule_matches_referenceto the parquetremove_problematicfor post genarris step to add an option of tagging and removing problematic structuresPost relaxation stage
validity.crystal_relaxed.correct_zvalidity.crystal_relaxed.molecule_matches_referenceto the parquetvalidity.crystal_relaxed.connectivity_unchanged