Crystal Graph Convolutional Neural Networks

Convolutional neural networks have proven remarkably effective at classifying highly correlated data with many layers of potential features, most notably in image classification. Graph convolutional neural networks, also called Message Passing Neural Networks (MPNNs), have recently shown promising preliminary results. These results suggest a strong potential application in discovering and classifying new materials.

Following the notation of Gilmer et al., the different forms of these MPNNs can be described by the specific definitions of their message and update functions. Hover over the variables and functions below to see their definitions.

$$ \texttip{m^{t+1}_v}{Message to be passed to the update function with vertex v, for iteration t+1} =\sum_{w\in \texttip{N(v)}{Set of neighbors of vertex v}}\texttip{M_t}{Message function for iteration t}(\texttip{h_v^t}{"Hidden" state associated with vertex v, for iteration t}, \texttip{h_w^t}{"Hidden" state associated with vertex w, for iteration t},\texttip{e_{vw}}{Edge vector associated with the connection between vertices v and w}) $$ $$ \texttip{h_v^{t+1}}{"Hidden" state associated with vertex v, for iteration t+1}= \texttip{U_t}{Vertex update function for iteration t}(\texttip{h_v^t}{"Hidden" state associated with vertex v, for iteration t}, \texttip{m_v^{t+1}}{Message to be passed to the update function with vertex v, for iteration t+1}) $$
where \(M_t\) is the message function and \(U_t\) is the vertex update function. Their exact forms are what distinguish one MPNN architecture from another.
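As a minimal sketch of one iteration of these two equations, assuming NumPy and hidden states stored as rows of an array, the code below passes \(M_t\) and \(U_t\) in as arguments. The helper name `mpnn_step` and the toy message and update functions in the usage lines are hypothetical. They illustrate the equations and do not correspond to any particular published model.

```python
import numpy as np

def mpnn_step(h, edges, edge_feats, message_fn, update_fn):
    """One message-passing iteration.

    h          : (N, F) array; row v is the hidden state h_v^t
    edges      : list of directed pairs (v, w), meaning w is a neighbor of v
    edge_feats : list of edge vectors e_vw, aligned with `edges`
    message_fn : M_t(h_v, h_w, e_vw) -> message (assumed same length as h_v here)
    update_fn  : U_t(h_v, m_v) -> new hidden state h_v^{t+1}
    """
    # m_v^{t+1}: sum of messages from all neighbors w of v (zero if v has none)
    messages = np.zeros_like(h, dtype=float)
    for (v, w), e_vw in zip(edges, edge_feats):
        messages[v] += message_fn(h[v], h[w], e_vw)
    # h_v^{t+1} = U_t(h_v^t, m_v^{t+1}) for every vertex v
    return np.stack([update_fn(h[v], messages[v]) for v in range(len(h))])

# Toy usage with hypothetical choices: M_t = e_vw * h_w, U_t = h_v + m_v
h = np.array([[1.0], [2.0], [3.0]])             # a 3-vertex path graph 0 - 1 - 2
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_feats = [1.0, 1.0, 1.0, 1.0]
h_next = mpnn_step(h, edges, edge_feats,
                   message_fn=lambda h_v, h_w, e: e * h_w,
                   update_fn=lambda h_v, m_v: h_v + m_v)
# h_next is [[3.], [6.], [5.]]
```

Swapping in other `message_fn` and `update_fn` choices changes the architecture while the summation over neighbors stays the same.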

A relevant and recent implementation of a specific GCNN architecture by Tian Xie and Jeffrey C. Grossman is available in their GitHub repository. The repository accompanies the paper Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties. In that work, the authors designed and implemented a graph convolutional neural network that can be trained to predict general properties of crystalline structures. To do this, they convert input .cif files into graphs (in an internal format) and pass these graphs to the network for training, together with labels for the desired predictions.

Crystal Graphs

The first step in applying GCNNs to crystalline structures is converting those structures into meaningful, unique graph representations. This step requires the designer to make several decisions. These include what defines an edge in the crystal graph and which feature vectors the nodes and edges should carry.

A common file format for crystalline structures is the Crystallographic Information File (CIF) standard, saved with the extension '.cif'. Accordingly, the data handling portion of Xie and Grossman's CGCNN takes CIF files as input and outputs its own internal graph structure, as opposed to the graph structure of the more recent pytorch_geometric library.

Ex: CIF to Graph Converter

A simple Python function that converts '.cif' files to a torch_geometric graph can be found here.

This function uses the '.get_neighbor_list()' method of pymatgen's Structure class. Every pair of sites within a chosen cutoff radius becomes an edge, stored as a tuple of site indices from the '.cif' file. Site positions are attached to the nodes, and interatomic distances are attached to the edges as features.

The function takes the path to a '.cif' file as a string and returns a torch_geometric graph object with the attributes listed above.
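The core step of such a converter can be sketched without pymatgen, assuming only NumPy. The code builds the same kind of edge list that pymatgen's `Structure.get_neighbor_list(r)` returns: center indices, neighbor indices, periodic image offsets, and distances. The function `radius_graph_periodic` is a hypothetical illustration and is not the linked converter.

```python
import itertools
import numpy as np

def radius_graph_periodic(lattice, frac_coords, cutoff):
    """Hypothetical helper: all (center, neighbor) pairs within `cutoff`, including periodic images.

    lattice     : (3, 3) array whose rows are the lattice vectors (same length unit as cutoff)
    frac_coords : (N, 3) fractional coordinates of the sites
    Returns (centers, neighbors, image_offsets, distances), one entry per edge.
    """
    lattice = np.asarray(lattice, dtype=float)
    frac = np.mod(np.asarray(frac_coords, dtype=float), 1.0)   # wrap sites into the cell
    cart = frac @ lattice                                       # Cartesian positions, (N, 3)
    # Rows of inv(L).T are reciprocal vectors b_k with a_m . b_k = delta_mk; the spacing
    # between lattice planes is 1/|b_k|, so ceil(cutoff * |b_k|) images per direction suffice.
    recip = np.linalg.inv(lattice).T
    n_img = np.ceil(cutoff * np.linalg.norm(recip, axis=1)).astype(int)

    centers, neighbors, offsets, dists = [], [], [], []
    for offset in itertools.product(*(range(-n, n + 1) for n in n_img)):
        shift = np.array(offset, dtype=float) @ lattice
        # d[i, j] = distance from site i to the image of site j translated by `offset`
        d = np.linalg.norm(cart[None, :, :] + shift - cart[:, None, :], axis=-1)
        i_idx, j_idx = np.nonzero((d <= cutoff) & (d > 1e-8))   # drop each site's self-distance
        centers.extend(i_idx.tolist())
        neighbors.extend(j_idx.tolist())
        offsets.extend([offset] * len(i_idx))
        dists.extend(d[i_idx, j_idx].tolist())
    return (np.array(centers, dtype=int), np.array(neighbors, dtype=int),
            np.array(offsets, dtype=int).reshape(-1, 3), np.array(dists, dtype=float))
```

For example, a simple cubic cell with unit edge length and one site, using a cutoff of 1.0, gives the 6 nearest neighbors at distance 1. Either these arrays or pymatgen's can then be packed into a `torch_geometric.data.Data` object. The site positions go in as `pos`, the stacked center and neighbor indices as `edge_index`, and the distances as `edge_attr`.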

CGCNN Architecture

The architecture of the CGCNN is relatively simple. It has an input layer that receives the chosen graph representation, \(R\) convolutional layers, and a fully connected hidden layer \(L_1\). These are followed by a pooling layer and a second hidden layer \(L_2\), which feeds into an output layer consisting of a single node.

The example architecture below is that of Xie and Grossman's CGCNN.

Figure: CGCNN architecture.
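The following minimal NumPy sketch makes the layer ordering concrete: \(R\) convolutions, hidden layer \(L_1\), pooling, hidden layer \(L_2\), and a single output node. Its convolution is a deliberately simple placeholder rather than Eq. (4) or (5), the activation is assumed to be softplus, pooling is a plain mean, and the keys of `params` are hypothetical. It illustrates the data flow, not Xie and Grossman's implementation.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)                      # log(1 + e^x), numerically stable

def cgcnn_like_forward(atom_feats, src, dst, params):
    """Hypothetical forward pass: R conv layers -> hidden L1 -> pooling -> hidden L2 -> output.

    atom_feats : (N, F) initial atom feature vectors
    src, dst   : (E,) edge arrays; atom src[e] receives features from neighbor dst[e]
    params     : dict of weights; len(params["W_conv"]) sets the number of conv layers R
    """
    v = atom_feats
    for W, b in zip(params["W_conv"], params["b_conv"]):
        # Placeholder convolution (not Eq. 4 or 5): add summed neighbor features, linear map, softplus
        agg = np.zeros_like(v)
        np.add.at(agg, src, v[dst])
        v = softplus((v + agg) @ W + b)
    v = softplus(v @ params["W_L1"] + params["b_L1"])      # hidden layer L1, applied per atom
    v_c = v.mean(axis=0)                                     # pooling: one vector for the crystal
    v_c = softplus(v_c @ params["W_L2"] + params["b_L2"])  # hidden layer L2
    return float(v_c @ params["w_out"] + params["b_out"])  # single output node
```

Atoms interact only through the edge sums and the final pooling step. Relabeling the atoms (and their edges consistently) therefore leaves the output unchanged, which is the property that lets the network treat a crystal graph independently of site ordering.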

Convolutional Layers

The convolutional layers in the CGCNN model use one of two convolution functions, given below with the equation numbers from the original article.

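\[\vec{v}^{(t+1)}_i=g\left[\left(\sum_{j,k}\vec{v}_j^{(t)}\oplus\vec{u}_{(i,j)_k}\right)\mathbf{W}^{(t)}_{c} +\vec{v}^{(t)}_{i}\mathbf{W}^{(t)}_{s}+\vec{b}^{(t)}\right] \quad (4)\] \[\vec{v}^{(t+1)}_i=\vec{v}^{(t)}_i+\sum_{j,k}\sigma\left[\vec{z}_{(i,j)_k}^{(t)}\mathbf{W}^{(t)}_f+\vec{b}_f^{(t)}\right] \odot g\left[\vec{z}_{(i,j)_k}^{(t)}\mathbf{W}^{(t)}_s+\vec{b}^{(t)}_{s}\right]\quad (5)\]
where \(\vec{u}_{(i,j)_k}\) is the feature vector of the \(k\)-th bond between atoms \(i\) and \(j\), and \(\oplus\) denotes concatenation. The vector \(\vec{z}^{(t)}_{(i,j)_k}=\vec{v}^{(t)}_i\oplus\vec{v}^{(t)}_j\oplus\vec{u}_{(i,j)_k}\) joins both atom vectors with the bond vector. \(\sigma\) is the sigmoid function, \(g\) is a nonlinear activation function, and the \(\mathbf{W}\) and \(\vec{b}\) terms are learned weights and biases.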
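The following minimal NumPy sketch implements one convolution of the form of Eq. (5). Each bond's input is the concatenation of the center atom's vector, the neighbor's vector, and the bond's feature vector. The sketch assumes that edges are stored as parallel `src`/`dst` index arrays with one row per bond, so multiple bonds \(k\) between the same pair of atoms appear as separate rows. It also assumes \(g\) is softplus. Both are assumptions here, and the sketch omits any normalization layers a full implementation might include.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softplus(x):
    return np.logaddexp(0.0, x)                              # log(1 + e^x), numerically stable

def gated_conv(v, u, src, dst, W_f, b_f, W_s, b_s):
    """One convolution in the form of Eq. (5).

    v        : (N, F) atom feature vectors v_i^(t)
    u        : (E, B) bond feature vectors, one row per bond (i, j)_k
    src, dst : (E,) arrays; row e is a bond from center atom src[e] to neighbor dst[e]
    W_f, W_s : (2F + B, F) gate ("filter") and core weight matrices
    b_f, b_s : (F,) biases
    """
    z = np.concatenate([v[src], v[dst], u], axis=1)          # z = concat(v_i, v_j, u_(i,j)_k)
    msg = sigmoid(z @ W_f + b_f) * softplus(z @ W_s + b_s)   # elementwise gate times g(...)
    out = v.astype(float)                                     # residual term v_i^(t) (a copy)
    np.add.at(out, src, msg)                                  # sum over all bonds (j, k) of atom i
    return out
```

The sigmoid factor acts as a learned gate on each bond's contribution, and the residual term \(\vec{v}^{(t)}_i\) lets each layer refine, rather than replace, the previous atom vectors.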

Pooling Layers

The pooling layer in the CGCNN model depends on the choice of convolutional function, and the two variants used are described below. Both are forms of normalized summation over the atoms in the cell. Such normalized pooling has the advantage of drawing information from the entire cell. Its disadvantage is that it dilutes the influence of any individual site.

\[\texttip{\vec{v}_c}{Learned crystal feature vector after pooling layer}=\texttip{\sum_{i,t}}{Sum over all atom feature vectors i and convolutional layers t}\text{Softmax}(\texttip{\mathbf{W}_t}{Learned weight matrix for t-th convolutional layer} \texttip{\vec{v}_i^{(t)}}{Atom i feature vector for t-th convolutional layer}+\texttip{\vec{b}_t}{Learned bias vector for t-th convolutional layer}) \quad\text{for}\ (4)\] \[\texttip{\vec{v}_c}{Learned crystal feature vector after pooling layer}=\frac{1}{\texttip{N}{Number of atoms in crystal}}\texttip{\sum_{i}}{Sum over i atoms in crystal} \texttip{\vec{v}_i^{(R)}}{Final learned feature vector after R convolutional layers for atom i}\quad\quad\text{for}\ (5)\]
Note that the first pooling function sums over the output feature vectors of every convolutional layer. The second averages only the feature vectors produced by the last (\(R\)-th) convolutional layer.
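The following minimal NumPy sketch implements the two pooling formulas. It assumes that Softmax acts over the components of each atom's transformed vector and that every layer's weight matrix maps into a shared output dimension. Both are readings of the formula rather than confirmed details. The function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))   # shift by the max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def mean_pool(v_final):
    """Pooling paired with Eq. (5): (1/N) * sum_i v_i^(R), with v_final of shape (N, F)."""
    return v_final.mean(axis=0)

def softmax_pool(layer_outputs, weights, biases):
    """Pooling paired with Eq. (4): sum over atoms i and layers t of Softmax(W_t v_i^(t) + b_t).

    layer_outputs[t] : (N, F_t) atom vectors after conv layer t
    weights[t]       : (F_t, F_c), so each layer maps into a shared F_c-dim space
    biases[t]        : (F_c,)
    Row-vector convention: v @ W here plays the role of W_t v_i in the formula.
    """
    v_c = 0.0
    for v_t, w_t, b_t in zip(layer_outputs, weights, biases):
        v_c = v_c + softmax(v_t @ w_t + b_t, axis=-1).sum(axis=0)   # softmax per atom, then sum
    return v_c
```

Because each softmax output sums to 1, the entries of the softmax-pooled vector always sum to \(N\) times the number of layers, whatever the weights. That is one sense in which both variants are normalized.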

Classification of Point Defects

An interesting question in this direction is whether CGCNN models can distinguish defect structures from perfect crystalline structures and classify them.

Generating Data

Defect data were generated using utilities from the Python package pymatgen, together with modified versions of custom Python functions written by Debajit Chakraborty.

For simplicity, we initially consider only point defects, which were created in randomly chosen cells from the C2DB (Computational 2D Materials Database). The point defect types considered are described below.

Defect Structures

Figure: example substitution, vacancy, antisite, and interstitial point defects.

Four common types of point defects in crystalline solids are interstitial, vacancy, antisite, and substitution defects. Each is confined to a single site of the structure and its immediate surroundings. Given their simple and localized nature, point defects are a natural starting point for training.
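The following minimal sketch shows the four defect types. It assumes a hypothetical toy representation of a cell as a list of element symbols plus a parallel list of fractional coordinates. Here an antisite is made by swapping two unlike atoms. pymatgen's `Structure` offers comparable methods such as `remove_sites`, `replace`, and `append`. These functions are illustrative only and are not the pymatgen utilities or the functions by Debajit Chakraborty used to generate the data.

```python
def make_vacancy(species, frac_coords, i):
    """Vacancy: remove the atom at site i."""
    s, c = list(species), list(frac_coords)
    del s[i]
    del c[i]
    return s, c

def make_substitution(species, frac_coords, i, new_element):
    """Substitution: replace the host atom at site i with a different element."""
    s = list(species)
    s[i] = new_element
    return s, list(frac_coords)

def make_antisite(species, frac_coords, i, j):
    """Antisite: swap two atoms of different elements, so each sits on the other's site."""
    if species[i] == species[j]:
        raise ValueError("an antisite needs two sites occupied by different elements")
    s = list(species)
    s[i], s[j] = s[j], s[i]
    return s, list(frac_coords)

def make_interstitial(species, frac_coords, element, position):
    """Interstitial: add an extra atom at a position that is not a regular lattice site."""
    return list(species) + [element], list(frac_coords) + [list(position)]

# Hypothetical toy cell with three sites (element symbols plus fractional coordinates)
species = ["Mo", "S", "S"]
frac = [[0.0, 0.0, 0.5], [1/3, 2/3, 0.42], [1/3, 2/3, 0.58]]
print(make_antisite(species, frac, 0, 1)[0])   # ['S', 'Mo', 'S']
```

Each function returns new lists and leaves the input cell unchanged, so a single perfect cell can seed many defective variants.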

Pooling Layer

Because point defects are localized, the normalized (mean) pooling used in the CGCNN paper proved ineffective. The few sites that define the defect were drowned out by the many other sites in the supercell. A max pooling layer, which takes the element-wise maximum over the atom feature vectors, was therefore implemented instead.
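The dilution argument can be illustrated with a toy NumPy sketch. It assumes hypothetical feature vectors in which every atom is all zeros except a single "defect" atom. The atom counts 12 and 108 are also hypothetical stand-ins for a small and a large supercell.

```python
import numpy as np

def mean_pool(v):
    """Normalized (mean) pooling over atoms: v_c = (1/N) * sum_i v_i."""
    return v.mean(axis=0)

def max_pool(v):
    """Max pooling over atoms: v_c[f] = max_i v_i[f] for each feature f."""
    return v.max(axis=0)

def toy_supercell(n_atoms, defect_strength=1.0):
    """Hypothetical features: every atom is all zeros except one 'defect' atom with a bump in feature 0."""
    v = np.zeros((n_atoms, 2))
    v[0, 0] = defect_strength
    return v

for n_atoms in (12, 108):          # hypothetical atom counts for a small and a large supercell
    v = toy_supercell(n_atoms)
    print(n_atoms, mean_pool(v)[0], max_pool(v)[0])
```

The mean-pooled defect signal shrinks as \(1/N\) (1/12 versus 1/108 here), while the max-pooled signal stays at 1.0 regardless of supercell size. This is consistent with the larger gap between the two pooling layers for the 6x6 supercells in the results below.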

A comparison of results for both the vanilla pooling layer (mean) and the new max pooling layer is provided in the test results below.

Test Results

The table below reports the best test results for the CGCNN trained as a binary classifier on one type of defect at a time. Defective supercells are labelled 1, and perfect supercells of the same size are labelled 0. The data were balanced 50/50 for testing, with as many defective supercells as perfect supercells, so random guessing would score about 50%.

Training ran for 30 epochs with a cutoff radius of 8 Å and at most 12 neighbors considered per atom. Initial atom feature vectors were taken from the 'atom_init.json' file in the sample data of the CGCNN GitHub repository.
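The following minimal NumPy sketch shows how the radius and the neighbor cap combine: each atom keeps at most its 12 nearest neighbors among those within 8 Å. It assumes an edge list of centers, neighbors, and distances in Å, such as a neighbor search produces. The function name and arguments are hypothetical. This is not the CGCNN repository's code, which may handle ties or atoms with too few neighbors differently.

```python
import numpy as np

def truncate_neighbors(centers, neighbors, dists, radius=8.0, max_neighbors=12):
    """Keep, for each center atom, at most `max_neighbors` nearest neighbors within `radius`.

    centers, neighbors, dists : parallel (E,) arrays describing directed edges
    Returns the filtered arrays, in their original edge order.
    """
    centers, neighbors, dists = (np.asarray(a) for a in (centers, neighbors, dists))
    keep = []
    for c in np.unique(centers):
        idx = np.nonzero((centers == c) & (dists <= radius))[0]    # candidate edges of atom c
        idx = idx[np.argsort(dists[idx], kind="stable")]            # nearest first
        keep.extend(idx[:max_neighbors].tolist())
    keep = np.sort(np.array(keep, dtype=int))
    return centers[keep], neighbors[keep], dists[keep]
```

Capping the neighbor count bounds the size of each atom's neighborhood in dense structures, while the radius excludes distant atoms in sparse ones.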

Best Test Results

| Defect | 2x2, Mean pooling | 2x2, Max pooling | 6x6, Mean pooling | 6x6, Max pooling |
|---|---|---|---|---|
| Vacancy | 62.5% | 84% | 49.2% | 91.8% |
| Antisite | 68.8% | 97.3% | 59.8% | 94.5% |
| Substitution\(^{(H/Li)}\) | 99.2% | 99.6% | 53.5% | 96.5% |
| Interstitial\(^{*(H/Li}_{\ Ag/Br)}\) | N/A | N/A | 50.8% | 96.9% |

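In every case with available results, the CGCNN performed better with max pooling than with mean pooling. The gap is largest for the 6x6 supercells. There, mean pooling stays close to the 50% chance level (49.2% to 59.8%), while max pooling reaches 91.8% to 96.9%.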


A book on machine learning applied to the quantum mechanical structures of molecules: Machine Learning Meets Quantum Physics.

Message Passing Neural Networks (MPNN)

For an example project, see this guide by Alexander Kensert on using Keras to build an MPNN that predicts molecular properties. Message passing networks are also straightforward to implement in PyTorch Geometric.

Other Relevant Tools

For reference, below are resources on computational tools for handling solid-state structures.

Pymatgen - File Support within Python

A large Python package for materials research. It allows the creation of Python objects representing unit cells of crystals as well as single molecules. A workshop for pymatgen is available on The Materials Project. Pymatgen also has a Symmetry Analyzer Module, which may be relevant for defects.

VASP - Material Property Calculator

A Fortran package that uses density functional theory (DFT) to compute quantitative properties of solid-state materials, such as band structures, wave functions, and chemical properties. VASP generally requires high-performance computing (HPC) resources for non-trivial calculations, and a license is required to obtain it.

VESTA - Visualization & File Support

VESTA may be used to view, create, and convert between file formats that represent solid state structures. It may be downloaded here.