Chenyu Tangsdgs
8/25/2025 - 3:09 PM

lock Deep Learning-Based Microstructural Image Analysis

Deep Learning-Based Microstructural Image Analysis

📖 Overview

This repository provides the official implementation of the paper: "Deep Learning-Based Image Classification for Microstructural Analysis in Computational Materials Science".

We introduce MorphoTensor, a novel deep generative framework that integrates physical, geometric, and topological priors for microstructural image analysis. Our method combines hierarchical tensorial embeddings, latent spatial warping, and topology-aware latent refinement to achieve accurate, interpretable, and robust microstructure classification.

✨ Features

- Hierarchical Tensor Convolutions: capture anisotropy, directionality, and multiscale spatial patterns.
- Latent Spatial Warping: models geometric irregularities such as grain boundaries and phase interfaces.
- Topology-Aware Latent Refinement: preserves Betti numbers and persistent homology structures.
- Robustness Across Datasets: validated on the HEDM, HTEM, Weld Seam, and MID datasets and on EBSD phase-field simulations.
- Improved Accuracy: outperforms CNN, VAE, ConvNeXt, ViT, and other baselines across classification benchmarks.

📊 Results

- HEDM Dataset: 81.47% accuracy (+2.22% over ConvNeXt).
- HTEM Dataset: 90.36% accuracy (+1.45% over ConvNeXt).
- Weld Seam Dataset: 93.15% accuracy, F1 score 92.51%.
- MID Dataset: 81.47% accuracy, significantly higher than DeiT (+5.34%).

(See Tables 2 and 3 in the paper for detailed benchmarking results.)

📂 Repository Structure

    ├── data/          # Example datasets and preprocessing scripts
    ├── models/        # MorphoTensor architecture and training modules
    ├── experiments/   # Benchmark experiments and configs
    ├── results/       # Sample outputs and evaluation metrics
    └── README.md      # Project documentation

🚀 Getting Started

Prerequisites

- Python 3.9+
- PyTorch (>= 1.12)
- CUDA 11.6 / 12.1

Other dependencies are listed in requirements.txt.

Installation

    git clone https://github.com/yourusername/microstructure-analysis.git
    cd microstructure-analysis
    pip install -r requirements.txt

Training

    python train.py --dataset HEDM --epochs 100 --batch_size 256

Evaluation

    python evaluate.py --checkpoint checkpoints/model_best.pth

📌 Citation

If you use this work, please cite:

BibTeX:

    @article{Tan2025Microstructure,
      title={Deep Learning-Based Image Classification for Microstructural Analysis in Computational Materials Science},
      author={Chenyu Tan},
      journal={Frontiers in Materials},
      year={2025}
    }

🙌 Acknowledgements

This work integrates ideas from computational materials science, topology, and deep learning. Special thanks to the supporting institutions and collaborators mentioned in the paper.

# -*- coding: utf-8 -*-
# MorphoTensor (PyTorch) — hierarchical tensor convs + spectral upsample + latent warping + (optional) topology proxy
# Author: Legend Co., Ltd.
# References:
#   - Hierarchical tensor filters & spectral upsample (Eq.5-7)
#   - Latent warping & regularizers (Eq.10-14)
#   - Topology-aware proxy (Eq.15-16)
#   - VAE objective & multiscale reconstruction (Eq.21-22)

from typing import Tuple, Dict, Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# Utils
# -----------------------------

def _make_coordinate_grid(h: int, w: int, device, dtype=torch.float32):
    """Return an (h, w, 2) grid of normalized coordinates spanning [-1, 1].

    The last axis holds (x, y): the column coordinate first, then the row
    coordinate. This is the ordering ``F.grid_sample`` expects. Corner pixels
    map exactly to -1 and +1, which matches ``align_corners=True``.
    """
    row_coords = torch.linspace(-1.0, 1.0, h, device=device, dtype=dtype)
    col_coords = torch.linspace(-1.0, 1.0, w, device=device, dtype=dtype)
    yy, xx = torch.meshgrid(row_coords, col_coords, indexing="ij")
    return torch.stack((xx, yy), dim=-1)

def _fftshift2(x: torch.Tensor) -> torch.Tensor:
    """Move the zero-frequency bin to the centre of the last two axes.

    Each of the last two axes, of length n, is rolled forward by ``n // 2``.
    This matches ``numpy.fft.fftshift`` over axes (-2, -1).
    """
    rows, cols = x.shape[-2], x.shape[-1]
    return torch.roll(x, shifts=(rows // 2, cols // 2), dims=(-2, -1))

def _ifftshift2(x: torch.Tensor) -> torch.Tensor:
    """Undo ``_fftshift2``: move the centred DC bin back to index (0, 0).

    Each of the last two axes, of length n, is rolled backward by ``n // 2``.
    This matches ``numpy.fft.ifftshift`` over axes (-2, -1), so it also
    inverts the forward shift for odd lengths.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    return torch.roll(x, shifts=(-(rows // 2), -(cols // 2)), dims=(-2, -1))

def spectral_upsample(x: torch.Tensor, scale: int = 2) -> torch.Tensor:
    """Upsample a real (B, C, H, W) tensor by zero-padding its 2-D spectrum.

    The centred spectrum of ``x`` is embedded in a zero array ``scale`` times
    larger on each spatial axis. The result is inverse-transformed and its
    real part is kept (the input is real).

    ``ifft2`` divides by the number of output samples, which is ``scale**2``
    times larger than the input's. The result is therefore multiplied by
    ``scale**2`` to keep the original amplitude. With ``scale == 1`` the
    input object itself is returned unchanged.

    Returns:
        Tensor of shape (B, C, H * scale, W * scale).
    """
    if scale == 1:
        return x
    batch, chans, in_h, in_w = x.shape
    out_h, out_w = in_h * scale, in_w * scale

    # After the shift, the DC bin sits at (in_h // 2, in_w // 2).
    spec = _fftshift2(torch.fft.fft2(x, dim=(-2, -1)))

    # Paste the small spectrum so its DC lands at the centre (out_h // 2, out_w // 2)
    # of the larger grid; every other bin stays zero.
    canvas = spec.new_zeros((batch, chans, out_h, out_w), dtype=spec.dtype)
    top = out_h // 2 - in_h // 2
    left = out_w // 2 - in_w // 2
    canvas[:, :, top:top + in_h, left:left + in_w] = spec

    upsampled = torch.fft.ifft2(_ifftshift2(canvas), dim=(-2, -1)).real
    return upsampled * (scale * scale)

def conv2d_same(x, weight, bias=None, stride=1):
    """2-D convolution with reflect padding that preserves spatial size.

    Args:
        x: input of shape (B, C_in, H, W), as required by ``F.conv2d``.
        weight: kernel of shape (C_out, C_in, kh, kw).
        bias: optional bias of shape (C_out,).
        stride: convolution stride. At ``stride == 1`` the output keeps the
            input's H and W. At larger strides it is ``ceil(H / stride)`` by
            ``ceil(W / stride)``.

    Returns:
        The convolved tensor.

    Notes:
        The total padding per axis is ``k - 1``. For odd kernels it is split
        evenly. For even kernels the extra row/column goes on the
        bottom/right. Previously both sides received ``(k - 1) // 2``, which
        shrank the output by one for even kernel sizes. Reflect padding
        requires each pad amount to be smaller than the matching input
        dimension.
    """
    kh, kw = weight.shape[-2:]
    pad_top = (kh - 1) // 2
    pad_left = (kw - 1) // 2
    pad_bottom = (kh - 1) - pad_top
    pad_right = (kw - 1) - pad_left
    # F.pad takes padding for the last dim first: (left, right, top, bottom).
    x = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom), mode="reflect")
    return F.conv2d(x, weight, bias=bias, stride=stride)

# -----------------------------
# Directional (tensor) filter bank (Eq.6)
# -----------------------------

def make_gabor_bank(r: int, ksize: int = 7, sigma: float = 2.0, lam: float = 3.0, gamma: float = 0.5, psi: float = 0.0, device="cpu"):
    """Return ``r`` real Gabor kernels at evenly spaced orientations.

    Kernel ``a`` uses orientation ``theta = pi * a / r``, so the bank covers
    [0, pi). Each kernel is sampled on a ``ksize`` x ``ksize`` grid whose
    coordinates run from ``-(ksize - 1) / 2`` to ``(ksize - 1) / 2``. It is a
    Gaussian envelope (``sigma``, aspect ratio ``gamma``) times a cosine
    carrier (wavelength ``lam``, phase ``psi``), divided by its L1 norm
    (+1e-8 to avoid division by zero).

    Returns:
        Tensor of shape (r, ksize, ksize).
    """
    radius = (ksize - 1) / 2.0
    axis = torch.linspace(-radius, radius, ksize, device=device)
    yy, xx = torch.meshgrid(axis, axis, indexing="ij")

    def _oriented_kernel(theta):
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        # Rotate the sampling coordinates by theta.
        x_rot = xx * cos_t + yy * sin_t
        y_rot = -xx * sin_t + yy * cos_t
        envelope = torch.exp(-(x_rot**2 + (gamma**2) * y_rot**2) / (2 * sigma**2))
        carrier = torch.cos(2 * math.pi * x_rot / lam + psi)
        kernel = envelope * carrier
        return kernel / (kernel.abs().sum() + 1e-8)

    return torch.stack([_oriented_kernel(math.pi * a / r) for a in range(r)], dim=0)

class DirectionalConv2d(nn.Module):
    """Convolution whose kernels are learned mixtures of a fixed Gabor bank.

    Every (output, input) channel pair owns ``r`` mixing coefficients over the
    ``r`` oriented basis kernels from ``make_gabor_bank``. Each pair therefore
    has ``r`` trainable weights rather than ``ksize ** 2``. Convolution is done
    by ``conv2d_same``, so the spatial size is preserved at stride 1.
    """

    def __init__(self, in_ch, out_ch, r=8, ksize=7):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        self.r, self.ksize = r, ksize
        init_scale = 1.0 / math.sqrt(in_ch * r)
        # Mixing coefficients, shape (out_ch, in_ch, r).
        self.coeff = nn.Parameter(torch.randn(out_ch, in_ch, r) * init_scale)
        # Fixed basis of shape (r, ksize, ksize). As a buffer it moves with
        # .to() and is saved in state_dict, but it is not a trainable parameter.
        self.register_buffer("basis", make_gabor_bank(r=r, ksize=ksize))
        self.bias = nn.Parameter(torch.zeros(out_ch))

    def forward(self, x):
        # Contract over the basis axis: kernel[o, i] = sum_r coeff[o, i, r] * basis[r].
        kernels = torch.einsum("oir,rkl->oikl", self.coeff, self.basis)
        return conv2d_same(x, kernels, bias=self.bias)

# -----------------------------
# Instance-wise modulation (Eq.9)
# -----------------------------

class AdaAffine(nn.Module):
    """Per-sample, per-channel affine modulation driven by a latent code.

    A two-layer MLP maps ``z`` (B, z_dim) to ``2 * feat_ch`` values. The first
    half is a scale and the second half a shift; both are broadcast over the
    spatial axes of ``feat`` (B, feat_ch, H, W). The scale is applied as-is,
    not as ``1 + scale``.
    """

    def __init__(self, feat_ch: int, z_dim: int):
        super().__init__()
        hidden = 2 * feat_ch
        self.mlp = nn.Sequential(
            nn.Linear(z_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, feat, z):
        n, c, _, _ = feat.shape
        scale, shift = self.mlp(z).chunk(2, dim=1)
        return scale.view(n, c, 1, 1) * feat + shift.view(n, c, 1, 1)

# -----------------------------
# Latent Spatial Warping (Eq.10-14)
# -----------------------------

class LatentWarper2D(nn.Module):
    """Predict a dense displacement field and resample features with it.

    A small conv net outputs two channels. After ``tanh`` and scaling by
    ``delta``, they form a displacement (x, y) bounded to [-delta, delta] in
    the normalized [-1, 1] sampling coordinates of ``F.grid_sample``. That
    displacement is added to the identity grid and the input is bilinearly
    resampled with border padding.

    ``forward`` also returns two smoothness regularizers computed from forward
    finite differences of the displacement:
      * ``l_warp``: mean squared Frobenius norm of (J - I);
      * ``l_det``:  mean of (det J - 1) ** 2.
    """

    def __init__(self, in_ch: int, delta: float = 0.25):
        super().__init__()
        self.delta = delta  # maximum |displacement| in normalized grid units
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 2, 3, padding=1),
        )

    @torch.no_grad()
    def _make_base_grid(self, H, W, device):
        # Identity sampling grid of shape (H, W, 2) in (x, y) order.
        return _make_coordinate_grid(H, W, device=device)

    def forward(self, feat: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        batch, _, height, width = feat.shape
        disp = torch.tanh(self.net(feat)) * self.delta  # (B, 2, H, W)
        identity = self._make_base_grid(height, width, feat.device)
        grid = identity.unsqueeze(0).expand(batch, -1, -1, -1) + disp.permute(0, 2, 3, 1)
        warped = F.grid_sample(feat, grid, mode="bilinear", padding_mode="border", align_corners=True)

        def grad_w(t):
            # Forward difference along width; the last column is zero-padded.
            return F.pad(t[:, :, :, 1:] - t[:, :, :, :-1], (0, 1, 0, 0))

        def grad_h(t):
            # Forward difference along height; the last row is zero-padded.
            return F.pad(t[:, :, 1:, :] - t[:, :, :-1, :], (0, 0, 0, 1))

        # NOTE(review): disp is in normalized units, but the differences are
        # taken per pixel (not divided by the 2 / (size - 1) grid spacing).
        # The J below is therefore not the Jacobian w.r.t. normalized coordinates.
        u = disp[:, 0:1]
        v = disp[:, 1:2]
        j11 = 1 + grad_w(u)
        j12 = grad_h(u)
        j21 = grad_w(v)
        j22 = 1 + grad_h(v)

        frob = (j11 - 1)**2 + j12**2 + j21**2 + (j22 - 1)**2
        det = j11 * j22 - j12 * j21
        regs = {
            "l_warp": frob.mean(),
            "l_det": ((det - 1.0) ** 2).mean(),
        }
        return warped, regs

# -----------------------------
# MorphoTensor Decoder / Generator (Eq.5-7)
# -----------------------------

class MTBlock(nn.Module):
    """One decoder stage: directional conv -> GroupNorm -> GELU -> latent modulation.

    ``out_ch`` must be divisible by 8, the fixed number of GroupNorm groups.
    """

    def __init__(self, in_ch, out_ch, z_dim, r=8, ksize=7):
        super().__init__()
        self.dirconv = DirectionalConv2d(in_ch, out_ch, r=r, ksize=ksize)
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.GELU()
        self.adain = AdaAffine(out_ch, z_dim)

    def forward(self, x, z):
        hidden = self.act(self.norm(self.dirconv(x)))
        # The modulation parameters come from the latent code z.
        return self.adain(hidden, z)

class MorphoTensorDecoder(nn.Module):
    """Generate a single-channel image from a latent code.

    ``z`` (B, z_dim) is linearly projected to a (B, base_ch, base_h, base_w)
    feature map. It then passes through one ``MTBlock`` per entry of
    ``layers``, with 2x spectral upsampling after every block except the last.
    The final features are warped by ``LatentWarper2D``, projected to one
    channel, and squashed with a sigmoid. With ``n = len(layers)``, the output
    is (B, 1, base_h * 2**(n-1), base_w * 2**(n-1)), i.e. 256x256 for the
    defaults. Every width in ``layers`` must be divisible by 8 (GroupNorm
    inside ``MTBlock``).

    Returns:
        (image, regs), where ``regs`` is the warper's regularizer dict.
    """

    def __init__(self, z_dim=128, base_hw: Tuple[int,int]=(16,16), base_ch=64, layers=(64,64,32,16,8), r=8):
        super().__init__()
        self.z_dim = z_dim
        self.base_h, self.base_w = base_hw
        self.fc = nn.Linear(z_dim, base_ch * self.base_h * self.base_w)
        widths = [base_ch, *layers]
        self.blocks = nn.ModuleList(
            MTBlock(c_in, c_out, z_dim, r=r) for c_in, c_out in zip(widths[:-1], widths[1:])
        )
        self.to_img = nn.Conv2d(widths[-1], 1, 3, padding=1)
        self.warper = LatentWarper2D(in_ch=widths[-1], delta=0.25)

    def forward(self, z):
        x = self.fc(z).view(z.size(0), -1, self.base_h, self.base_w)
        last_idx = len(self.blocks) - 1
        feats = None
        for idx, block in enumerate(self.blocks):
            x = block(x, z)
            feats = x
            if idx != last_idx:
                x = spectral_upsample(x, scale=2)
        warped, regs = self.warper(feats)
        return torch.sigmoid(self.to_img(warped)), regs

# -----------------------------
# Encoder (Eq.20)
# -----------------------------

class Encoder(nn.Module):
    """Convolutional encoder returning diagonal-Gaussian posterior parameters.

    The input is single-channel, (B, 1, H, W). One stride-1 conv and three
    stride-2 convs (GELU after each) are followed by adaptive average pooling
    to 4x4, so the linear heads work for any H and W.

    Returns:
        (mu, logvar), each of shape (B, z_dim).
    """

    def __init__(self, z_dim=128):
        super().__init__()
        widths = (1, 32, 64, 128, 256)
        strides = (1, 2, 2, 2)
        modules = []
        for c_in, c_out, step in zip(widths[:-1], widths[1:], strides):
            modules += [nn.Conv2d(c_in, c_out, 3, stride=step, padding=1), nn.GELU()]
        modules.append(nn.AdaptiveAvgPool2d((4, 4)))
        self.net = nn.Sequential(*modules)
        flat_dim = widths[-1] * 4 * 4
        self.fc_mu = nn.Linear(flat_dim, z_dim)
        self.fc_lv = nn.Linear(flat_dim, z_dim)

    def forward(self, u):
        pooled = self.net(u).flatten(1)
        return self.fc_mu(pooled), self.fc_lv(pooled)

def reparameterize(mu, logvar):
    """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

    The noise is drawn separately from the parameters, so gradients flow to
    ``mu`` and ``logvar`` (the reparameterization trick).
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

# -----------------------------
# Topology-aware proxy loss (Eq.15-16, proxy)
# -----------------------------

class EulerProxyLoss(nn.Module):
    """Differentiable Euler-characteristic penalty on soft binary images.

    ``u01`` is read as per-pixel occupancy probabilities. On the pixel lattice
    it defines a soft cubical complex: pixels are vertices, horizontally or
    vertically adjacent pixel pairs are edges, and 2x2 blocks are square
    faces. The expected counts are

        V = sum of p over all pixels
        E = sum of p_a * p_b over all horizontal and vertical neighbour pairs
        F = sum over all 2x2 blocks of the product of their four values

    and chi = V - E + F. For a binary image this is exactly the Euler
    characteristic of the complex (connected components minus holes). For
    example, one isolated pixel gives 1 and a one-pixel-thick ring gives 0.

    Each vertex and edge is counted exactly once. The earlier per-window sums
    counted interior pixels ~4 times and edges ~2 times, so chi was not an
    Euler characteristic.

    Args:
        target_chi: desired Euler characteristic. ``None`` disables the loss.
        weight: multiplier applied to the mean squared error.
    """

    def __init__(self, target_chi: Optional[float] = None, weight: float = 1.0):
        super().__init__()
        self.target_chi = target_chi
        self.weight = weight

    def forward(self, u01: torch.Tensor) -> torch.Tensor:
        """Return ``weight * mean((chi - target_chi) ** 2)`` over the leading two axes.

        Expects ``u01`` with spatial axes last, e.g. (B, C, H, W). chi is
        computed per (batch, channel) pair. Returns a zero scalar when
        ``target_chi`` is None.
        """
        if self.target_chi is None:
            return u01.new_zeros(())
        # Clamp to [0, 1] so the counts remain expectations of binary
        # indicators. Unlike a clamp to [eps, 1 - eps], this adds no
        # per-pixel bias to V.
        p = u01.clamp(0.0, 1.0)
        spatial = (2, 3)
        n_vertices = p.sum(dim=spatial)
        n_edges = (p[:, :, :, :-1] * p[:, :, :, 1:]).sum(dim=spatial) \
            + (p[:, :, :-1, :] * p[:, :, 1:, :]).sum(dim=spatial)
        n_faces = (p[:, :, :-1, :-1] * p[:, :, :-1, 1:]
                   * p[:, :, 1:, :-1] * p[:, :, 1:, 1:]).sum(dim=spatial)
        chi = n_vertices - n_edges + n_faces
        return ((chi - self.target_chi) ** 2).mean() * self.weight

# -----------------------------
# Full VAE model
# -----------------------------

class MorphoTensorVAE(nn.Module):
    """VAE combining ``Encoder`` and ``MorphoTensorDecoder`` with optional regularizers.

    Args:
        z_dim: latent dimensionality.
        topo_target: target Euler characteristic for ``EulerProxyLoss``
            (``None`` disables it).
        topo_w: weight of the topology term.
        warp_w: weight of the decoder warper's ``l_warp`` regularizer.
        det_w: weight of the decoder warper's ``l_det`` regularizer.

    ``forward(x)`` returns a dict with the reconstruction, posterior
    parameters, each loss term and their sum under ``"loss"``. The
    reconstruction loss is ``F.mse_loss(recon, x)``, so ``x`` must match the
    decoder's output size. With the default decoder that is (B, 1, 256, 256).
    The KL term is averaged (not summed) over batch and latent dimensions.
    """

    def __init__(self, z_dim=128, topo_target: Optional[float]=None, topo_w=0.0, warp_w=0.0, det_w=0.0):
        super().__init__()
        self.enc = Encoder(z_dim=z_dim)
        self.dec = MorphoTensorDecoder(z_dim=z_dim)
        self.topo_loss = EulerProxyLoss(target_chi=topo_target, weight=topo_w)
        self.warp_w = warp_w
        self.det_w = det_w

    def forward(self, x) -> Dict[str, torch.Tensor]:
        mu, logvar = self.enc(x)
        recon, regs = self.dec(reparameterize(mu, logvar))

        l_rec = F.mse_loss(recon, x)
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        l_topo = self.topo_loss(recon)
        fallback = torch.zeros_like(kl)  # used if the decoder omits a regularizer
        l_warp = regs.get("l_warp", fallback) * self.warp_w
        l_det = regs.get("l_det", fallback) * self.det_w

        return {
            "recon": recon,
            "mu": mu,
            "logvar": logvar,
            "l_rec": l_rec,
            "l_kl": kl,
            "l_topo": l_topo,
            "l_warp": l_warp,
            "l_det": l_det,
            "loss": l_rec + kl + l_topo + l_warp + l_det,
        }

# -----------------------------
# Example
# -----------------------------
if __name__ == "__main__":
    # Smoke run: one forward pass on random 256x256 single-channel images.
    # Prints every scalar output as a float and the shape of non-scalar
    # tensors (the reconstruction itself is skipped).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch, height, width = 2, 256, 256
    x = torch.rand(batch, 1, height, width, device=device)

    model = MorphoTensorVAE(z_dim=128, topo_target=None, topo_w=0.0, warp_w=1e-3, det_w=1e-3).to(device)
    out = model(x)

    summary = {}
    for key, value in out.items():
        if key == "recon":
            continue
        if torch.is_tensor(value) and value.dim() == 0:
            summary[key] = float(value.detach().cpu())
        elif torch.is_tensor(value):
            summary[key] = value.shape
        else:
            summary[key] = value
    print(summary)