# Enhancing Named Entity Recognition for Immunology and Immune-Mediated Disorders

## 📖 Overview

This repository implements and reproduces the key methods from the paper "Enhancing Named Entity Recognition for Immunology and Immune-Mediated Disorders." The work proposes a comprehensive NER framework tailored to the biomedical and immunology domains, integrating the SpanStructuredEncoder, Contextual Constraint Decoding (CCD), and Graph-Guided Context Propagation to improve recognition accuracy and generalization in low-resource settings.
## 🚀 Key Features

- **Hierarchical Span Representation (SpanStructuredEncoder):** Aggregates multi-level semantic embeddings for robust entity boundary detection.
- **Contextual Constraint Decoding (CCD):** Incorporates global semantic consistency and ontology-based constraints for coherent prediction.
- **Type-Aware Coherence Regularization:** Encourages entity embeddings of the same type to maintain semantic uniformity, reducing ambiguity.
- **Cross-Dataset Robustness:** Outperforms multiple state-of-the-art baselines on the Curia Simulation, Waymo Open, ApolloScape, and NGSIM datasets.

## 🧠 Model Architecture

The proposed framework consists of three major components (Section 3 of the paper):
1. **SpanStructuredEncoder** — a multi-channel representation encoder for boundary and context aggregation.
2. **Graph-Guided Context Propagation** — a graph attention mechanism capturing inter-entity semantic relations.
3. **Contextual Constraint Decoding (CCD)** — a structured inference layer integrating type and ontology consistency.

The diagrams on pages 9–13 (Figures 2–4) visualize the full architecture and component interactions.
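For intuition, the CCD non-overlap constraint amounts to greedy non-maximum suppression over 1D token spans. A minimal standalone sketch of that idea in plain Python (illustrative only, not the repository's exact implementation; see `CCDDecoder` in `model.py`):

```python
def iou_1d(a, b):
    """IoU of two inclusive token intervals (start, end)."""
    inter = max(0, min(a[1], b[1]) - max(a[0], b[0]) + 1)
    union = (a[1] - a[0] + 1) + (b[1] - b[0] + 1) - inter
    return inter / union

def greedy_nms(spans, scores, iou_threshold=0.6):
    """Keep highest-scoring spans first; drop lower-scored overlapping ones."""
    order = sorted(range(len(spans)), key=lambda i: -scores[i])
    keep = []
    for i in order:
        if all(iou_1d(spans[i], spans[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

# Spans (2, 5) and (3, 5) overlap with IoU 0.75 > 0.6, so the lower-scored
# span is suppressed; (8, 9) is disjoint and survives.
print(greedy_nms([(2, 5), (3, 5), (8, 9)], [0.9, 0.8, 0.7]))  # [0, 2]
```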
## 📊 Experimental Results

- On Curia Simulation and Waymo Open, the model improves F1-scores by +2.7% and +1.9%, respectively.
- Ablation studies on ApolloScape and NGSIM validate each module's contribution.
- The model maintains stable performance across multiple biomedical entity types (diseases, cells, molecules).

## ⚙️ Environment Setup

```
python >= 3.9
torch >= 2.0
transformers >= 4.30
scikit-learn
numpy
pandas
matplotlib
```

Install dependencies:
```bash
pip install -r requirements.txt
```

## 🧩 Usage

### Prepare Data

Place datasets (e.g., NCBI-Disease, BC5CDR, Curia Simulation) under the `data/` directory.
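The training code expects gold entities as token-level `(start, end, label_id)` triples (see `ImmunoSpanNER.forward` in `model.py`), while biomedical corpora typically ship character-offset annotations. A hypothetical conversion sketch, assuming simple whitespace tokenization — real datasets should instead use the tokenizer's offset mapping:

```python
def char_to_token_spans(text, char_entities):
    """char_entities: list of (char_start, char_end_exclusive, label_id)."""
    # Record each whitespace token's character offsets.
    tokens, offsets, pos = [], [], 0
    for tok in text.split():
        start = text.index(tok, pos)
        offsets.append((start, start + len(tok)))
        tokens.append(tok)
        pos = start + len(tok)
    spans = []
    for cs, ce, label in char_entities:
        # A token belongs to the entity if its character range overlaps it.
        covered = [i for i, (ts, te) in enumerate(offsets) if ts < ce and te > cs]
        if covered:
            spans.append((covered[0], covered[-1], label))  # inclusive token indices
    return tokens, spans

text = "IL-6 drives immune-mediated inflammation"
tokens, spans = char_to_token_spans(text, [(0, 4, 0)])  # "IL-6" labeled with id 0
print(spans)  # [(0, 0, 0)]
```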
### Train the Model

```bash
python train.py --config configs/ner_config.yaml
```

### Evaluate the Model

```bash
python evaluate.py --checkpoint checkpoints/best_model.pt
```

### Run Inference

```bash
python predict.py --text "IL-6 and TNF-alpha are critical in immune-mediated inflammation."
```

## 📚 Datasets

- NCBI Disease Corpus
- Curia Simulation Dataset
- Waymo Open Dataset
- ApolloScape
- NGSIM

Detailed preprocessing and annotation guidelines are provided in Section 4.1 of the paper.
## 📈 Visualization

Figures 5–8 (pages 17–22) in the paper illustrate:

- Comparison with SOTA methods.
- Ablation studies.
- Precision, Recall, and F1 metrics across entity types.

## 🔬 Future Work

According to the paper's conclusion:
- Extend the model with self-supervised and cross-modal knowledge integration.
- Apply the framework to related biomedical NLP tasks, such as drug–target prediction.
- Explore synergy with large language models (LLMs) for better contextual reasoning.

## 🧾 Citation

If you use this repository or reference the original paper, please cite:
```bibtex
@article{chen2025enhancing,
  title={Enhancing Named Entity Recognition for Immunology and Immune-Mediated Disorders},
  author={Jinhan Chen and Mingxiang Sun and Yuhong Wang},
  journal={Frontiers in Immunology},
  year={2025}
}
```
# model.py
# Minimal research-grade implementation for span-based NER with
# SpanStructuredEncoder, Graph-Guided Context Propagation, and CCD decoding.
# Author: Legend Co., Ltd. (Rebecca)
# Python >=3.9, torch >=2.0, transformers >=4.30
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from transformers import AutoModel, AutoConfig
except ImportError:
AutoModel, AutoConfig = None, None
print(
"[model.py] transformers is not installed. "
"Install with: pip install transformers"
)
# -----------------------------
# Utilities
# -----------------------------
def lengths_to_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
"""
lengths: (B,) lengths
returns: (B, T) boolean mask where True = valid
"""
b = lengths.size(0)
t = int(max_len or lengths.max().item())
range_ = torch.arange(t, device=lengths.device).unsqueeze(0).expand(b, t)
return range_ < lengths.unsqueeze(1)
def batched_index_select(values: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""
values: (B, T, H)
index: (B, K) indices into T
returns: (B, K, H)
"""
b, t, h = values.shape
k = index.size(1)
offset = torch.arange(b, device=values.device).unsqueeze(1) * t # (B,1)
flat = values.reshape(b * t, h)
flat_index = (index + offset).reshape(-1)
gathered = flat.index_select(0, flat_index)
return gathered.view(b, k, h)
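# Example: for values of shape (2, 4, 1) holding 0..7 and index [[0, 3], [2, 2]],
# the batch offsets are [0, 4], so the flattened indices become [0, 3, 6, 6] and
# the gathered result is [[[0.], [3.]], [[6.], [6.]]] with shape (2, 2, 1).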
# -----------------------------
# Config
# -----------------------------
@dataclass
class ModelConfig:
transformer: str = "bert-base-cased"
hidden_size: int = 768 # will be overwritten by transformer config if available
span_width: int = 8 # maximum span width (in tokens)
span_head_dim: int = 128
gcn_hidden: int = 256
gcn_layers: int = 2
dropout: float = 0.1
# decoding/ccd
non_overlap: bool = True
nms_threshold: float = 0.6 # IoU threshold for span suppression
topk_per_doc: int = 200 # keep top-K spans for graph + decoding
# -----------------------------
# SpanStructuredEncoder
# -----------------------------
class SpanStructuredEncoder(nn.Module):
"""
Multi-channel span representation:
- boundary embeddings (start/end)
- head attention pooling inside span
- width embedding
"""
def __init__(self, hidden_size: int, span_width: int, head_dim: int, dropout: float):
super().__init__()
self.hidden_size = hidden_size
self.span_width = span_width
self.dropout = nn.Dropout(dropout)
self.width_embed = nn.Embedding(span_width, hidden_size)
self.head_scorer = nn.Linear(hidden_size, 1)
        self.proj = nn.Linear(hidden_size * 4, head_dim)  # start + end + head + width channels
self.layer_norm = nn.LayerNorm(head_dim)
def forward(
self,
token_reprs: torch.Tensor, # (B, T, H)
attention_mask: torch.Tensor, # (B, T) bool
spans: torch.Tensor # (B, S, 2) start,end inclusive
) -> torch.Tensor:
b, t, h = token_reprs.shape
s = spans.size(1)
        start_idx = spans[..., 0].clamp(0, t - 1)  # (B,S)
        end_idx = spans[..., 1].clamp(0, t - 1)    # (B,S); clamp so wide spans near the
                                                   # sequence end never index past T
span_len = end_idx - start_idx + 1 # (B,S)
# Clamp for width embedding
width_idx = torch.clamp(span_len - 1, min=0, max=self.span_width - 1)
width_vec = self.width_embed(width_idx) # (B,S,H)
# Boundary embeddings
start_vec = batched_index_select(token_reprs, start_idx) # (B,S,H)
end_vec = batched_index_select(token_reprs, end_idx) # (B,S,H)
# Head attention pooling within spans
# Build index grid (B,S,Wmax)
wmax = int(self.span_width)
rel = torch.arange(wmax, device=token_reprs.device).unsqueeze(0).unsqueeze(0) # (1,1,W)
pos = start_idx.unsqueeze(-1) + rel # (B,S,W)
pos_mask = pos <= end_idx.unsqueeze(-1) # valid positions within each span
pos = torch.clamp(pos, 0, t - 1)
# Gather token reps for each candidate position
span_tokens = batched_index_select(token_reprs, pos.view(b, -1)) # (B, S*W, H)
span_tokens = span_tokens.view(b, s, wmax, h) # (B,S,W,H)
attn_scores = self.head_scorer(span_tokens).squeeze(-1) # (B,S,W)
attn_scores = attn_scores.masked_fill(~pos_mask, -1e30)
attn = F.softmax(attn_scores, dim=-1).unsqueeze(-1) # (B,S,W,1)
head_vec = (span_tokens * attn).sum(dim=2) # (B,S,H)
# Concatenate channels
z = torch.cat([start_vec, end_vec, head_vec, width_vec], dim=-1) # (B,S,4H)
z = self.proj(self.dropout(z))
z = self.layer_norm(F.gelu(z))
return z # (B,S,head_dim)
# -----------------------------
# Graph-Guided Context Propagation
# -----------------------------
class GraphAttention(nn.Module):
"""
Single-head graph attention on spans.
"""
def __init__(self, dim_in: int, dim_out: int, dropout: float):
super().__init__()
self.lin_q = nn.Linear(dim_in, dim_out)
self.lin_k = nn.Linear(dim_in, dim_out)
self.lin_v = nn.Linear(dim_in, dim_out)
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim_out)
def forward(self, span_reprs: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""
span_reprs: (B, S, D)
adj: (B, S, S) 0/1 mask (include self)
"""
q = self.lin_q(span_reprs)
k = self.lin_k(span_reprs)
v = self.lin_v(span_reprs)
scale = 1.0 / math.sqrt(q.size(-1))
scores = torch.matmul(q, k.transpose(-1, -2)) * scale # (B,S,S)
scores = scores.masked_fill(~adj.bool(), -1e30)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(self.dropout(attn), v) # (B,S,D)
out = self.norm(out + q) # residual
return out
class GraphPropagation(nn.Module):
"""
Stacked GraphAttention layers; adjacency computed from:
- Span overlap
- Representation cosine-similarity (>= threshold)
"""
def __init__(self, dim: int, hidden: int, layers: int, dropout: float, cos_th: float = 0.3):
super().__init__()
self.layers = nn.ModuleList(
[GraphAttention(dim if i == 0 else hidden, hidden, dropout) for i in range(layers)]
)
self.cos_th = cos_th
@staticmethod
def _overlap_adj(spans: torch.Tensor) -> torch.Tensor:
"""
spans: (B,S,2)
returns: (B,S,S) bool matrix where True if spans overlap or identical.
"""
b, s, _ = spans.shape
start = spans[..., 0].unsqueeze(2) # (B,S,1)
end = spans[..., 1].unsqueeze(2) # (B,S,1)
start_t = spans[..., 0].unsqueeze(1) # (B,1,S)
end_t = spans[..., 1].unsqueeze(1) # (B,1,S)
# overlap if max(start_i, start_j) <= min(end_i, end_j)
max_start = torch.maximum(start, start_t)
min_end = torch.minimum(end, end_t)
return max_start <= min_end
def _cos_adj(self, span_reprs: torch.Tensor) -> torch.Tensor:
# cosine similarity adjacency
b, s, d = span_reprs.shape
x = F.normalize(span_reprs, p=2, dim=-1)
sim = torch.matmul(x, x.transpose(-1, -2)) # (B,S,S)
return sim >= self.cos_th
def forward(self, span_reprs: torch.Tensor, spans: torch.Tensor) -> torch.Tensor:
adj = self._overlap_adj(spans) | self._cos_adj(span_reprs)
# include self loops
eye = torch.eye(span_reprs.size(1), device=span_reprs.device).bool()
adj = adj | eye.unsqueeze(0)
h = span_reprs
for layer in self.layers:
h = layer(h, adj)
return h # (B,S,H)
# -----------------------------
# Span Proposal + Classification
# -----------------------------
class SpanProposal(nn.Module):
"""
Generates candidate spans up to max width and scores them.
"""
def __init__(self, max_width: int, hidden: int):
super().__init__()
self.max_width = max_width
self.scorer = nn.Linear(hidden, 1)
def enumerate_spans(self, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
lengths: (B,)
returns:
spans: (B,S,2) start,end
mask: (B,S) valid
"""
b = lengths.size(0)
t = int(lengths.max().item())
spans = []
for w in range(1, self.max_width + 1):
start = torch.arange(t, device=lengths.device).unsqueeze(0).expand(b, -1) # (B,T)
end = start + (w - 1)
spans.append(torch.stack([start, end], dim=-1)) # (B,T,2)
spans = torch.cat(spans, dim=1) # (B, T*W, 2)
# validity
valid = (spans[..., 1] < lengths.unsqueeze(1)) & (spans[..., 0] >= 0)
return spans, valid
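    # Example: with t = 3 and max_width = 2, the enumerated spans are
    # (0,0), (1,1), (2,2), (0,1), (1,2), (2,3); the last is marked invalid for a
    # length-3 sequence because its end index falls outside the text.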
def forward(self, span_reprs: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
"""
span_reprs: (B,S,D)
valid_mask: (B,S)
returns scores: (B,S)
"""
scores = self.scorer(span_reprs).squeeze(-1)
scores = scores.masked_fill(~valid_mask, -1e30)
return scores
class SpanClassifier(nn.Module):
def __init__(self, dim_in: int, num_labels: int, dropout: float):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(dim_in, num_labels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.classifier(self.dropout(x)) # (B,S,C)
# -----------------------------
# Contextual Constraint Decoding (CCD)
# -----------------------------
class CCDDecoder(nn.Module):
"""
Lightweight CCD:
- applies ontology/type masks (allowed label set per span if provided)
- enforces non-overlap via greedy NMS on span logits
"""
def __init__(self, nms_threshold: float = 0.6, non_overlap: bool = True):
super().__init__()
self.nms_threshold = nms_threshold
self.non_overlap = non_overlap
@staticmethod
def _iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
a: (N,2), b: (M,2), returns IoU matrix (N,M) for 1D intervals.
"""
start = torch.maximum(a[:, 0].unsqueeze(1), b[:, 0].unsqueeze(0))
end = torch.minimum(a[:, 1].unsqueeze(1), b[:, 1].unsqueeze(0))
inter = torch.clamp(end - start + 1, min=0)
len_a = (a[:, 1] - a[:, 0] + 1).unsqueeze(1)
len_b = (b[:, 1] - b[:, 0] + 1).unsqueeze(0)
union = len_a + len_b - inter
return inter / torch.clamp(union, min=1e-6)
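    # Example: intervals (0, 3) and (2, 5) share tokens 2-3, so the intersection
    # is 2, the union is 4 + 4 - 2 = 6, and the IoU is 1/3.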
def forward(
self,
spans: torch.Tensor, # (S,2)
label_logits: torch.Tensor, # (S,C)
allowed_mask: Optional[torch.Tensor] = None # (S,C) 1=allowed
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
keep_spans: (K,2)
keep_labels: (K,)
keep_scores: (K,)
"""
        S, C = label_logits.shape
        if allowed_mask is not None:
            # Ontology/type constraint: forbid disallowed labels before scoring.
            label_logits = label_logits.masked_fill(~allowed_mask.bool(), -1e30)
        scores, labels = label_logits.softmax(-1).max(-1)  # (S,), (S,)
order = torch.argsort(scores, descending=True)
keep = []
while order.numel() > 0:
i = order[0].item()
keep.append(i)
order = order[1:]
if self.non_overlap and order.numel() > 0:
iou = self._iou(spans[i:i+1], spans[order]).squeeze(0) # (N,)
order = order[iou <= self.nms_threshold]
keep = torch.tensor(keep, device=spans.device, dtype=torch.long)
return spans[keep], labels[keep], scores[keep]
# -----------------------------
# Full Model
# -----------------------------
class ImmunoSpanNER(nn.Module):
"""
End-to-end:
- Transformer encoder
- SpanStructuredEncoder
- GraphPropagation
- Span scoring + classification
- CCD decoding
"""
def __init__(self, num_labels: int, cfg: ModelConfig):
super().__init__()
self.cfg = cfg
if AutoConfig is None:
raise RuntimeError("transformers not available. Please install it.")
base_cfg = AutoConfig.from_pretrained(cfg.transformer)
self.encoder = AutoModel.from_pretrained(cfg.transformer)
hidden = base_cfg.hidden_size
self.span_enc = SpanStructuredEncoder(
hidden_size=hidden,
span_width=cfg.span_width,
head_dim=cfg.span_head_dim,
dropout=cfg.dropout,
)
self.graph = GraphPropagation(
dim=cfg.span_head_dim,
hidden=cfg.gcn_hidden,
layers=cfg.gcn_layers,
dropout=cfg.dropout,
)
self.proposal = SpanProposal(cfg.span_width, cfg.gcn_hidden)
self.classifier = SpanClassifier(cfg.gcn_hidden, num_labels, cfg.dropout)
self.ccd = CCDDecoder(cfg.nms_threshold, cfg.non_overlap)
self.dropout = nn.Dropout(cfg.dropout)
def forward(
self,
input_ids: torch.Tensor, # (B,T)
attention_mask: torch.Tensor, # (B,T)
lengths: torch.Tensor, # (B,)
labels: Optional[List[List[Tuple[int,int,int]]]] = None,
ontology_masks: Optional[torch.Tensor] = None, # (B,S,C) optional per span
decode: bool = False,
):
"""
labels (optional training): list per batch; each element is list of (start,end,label_id)
"""
b, t = input_ids.size()
# 1) Transformer encoder
enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state # (B,T,H)
# 2) Enumerate spans (max width W)
spans, valid = self.proposal.enumerate_spans(lengths) # (B,S,2), (B,S)
S = spans.size(1)
# 3) SpanStructuredEncoder
span_repr = self.span_enc(enc, attention_mask.bool(), spans) # (B,S,D)
# 4) Graph-Guided Context Propagation
span_repr = self.graph(span_repr, spans) # (B,S,H)
# 5) Span scoring + classification
span_scores = self.proposal(span_repr, valid) # (B,S)
class_logits = self.classifier(span_repr) # (B,S,C)
outputs = {
"spans": spans,
"valid_mask": valid,
"span_scores": span_scores,
"class_logits": class_logits,
}
# 6) Loss (multi-label span classification with gold spans)
if labels is not None:
# Build gold target tensor: (B,S) span existence; (B,S,C) class one-hot for gold spans
device = input_ids.device
C = class_logits.size(-1)
gold_span = torch.zeros(b, S, device=device, dtype=torch.bool)
gold_class = torch.zeros(b, S, C, device=device)
for i, gold in enumerate(labels):
for (st, ed, lab) in gold:
# find exact matches among enumerated spans
eq = (spans[i, :, 0] == st) & (spans[i, :, 1] == ed)
if eq.any():
idx = torch.nonzero(eq, as_tuple=False).squeeze(-1)
gold_span[i, idx] = True
gold_class[i, idx, lab] = 1.0
# span existence (objectness) loss (binary)
span_obj_logits = span_scores
span_obj_target = gold_span.float()
obj_loss = F.binary_cross_entropy_with_logits(
span_obj_logits[valid], span_obj_target[valid]
)
# classification loss only for gold spans
if gold_class.sum() > 0:
cls_loss = F.cross_entropy(
class_logits[gold_span], gold_class[gold_span].argmax(-1)
)
else:
cls_loss = torch.tensor(0.0, device=device)
outputs["loss"] = obj_loss + cls_loss
# 7) CCD decoding at inference
if decode:
decoded = []
for i in range(b):
# top-K filter by span proposal
valid_i = valid[i]
scores_i = span_scores[i].masked_fill(~valid_i, -1e30)
                topk = min(self.cfg.topk_per_doc, int(valid_i.sum().item()))
                if topk == 0:  # no valid spans for this document
                    empty = spans.new_zeros(0, 2)
                    decoded.append((empty, empty.new_zeros(0), class_logits.new_zeros(0)))
                    continue
                _, idxs = torch.topk(scores_i, k=topk)
spans_i = spans[i, idxs] # (K,2)
logits_i = class_logits[i, idxs] # (K,C)
allowed_i = None
if ontology_masks is not None:
allowed_i = ontology_masks[i, idxs] # (K,C)
keep_spans, keep_labels, keep_scores = self.ccd(
spans_i, logits_i, allowed_i
)
decoded.append((keep_spans, keep_labels, keep_scores))
outputs["decoded"] = decoded
return outputs
# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
# Quick smoke test with random data
# NOTE: Run `pip install transformers torch` beforehand.
device = "cuda" if torch.cuda.is_available() else "cpu"
cfg = ModelConfig()
num_labels = 5 # e.g., {Disease, Cell, Molecule, Process, Other}
model = ImmunoSpanNER(num_labels=num_labels, cfg=cfg).to(device)
model.eval()
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(cfg.transformer)
texts = [
"IL-6 and TNF-alpha are critical in immune-mediated inflammation.",
"B cells interact with T cells in germinal centers during adaptive immunity.",
]
batch = tok(
texts, padding=True, truncation=True, return_tensors="pt", max_length=128
).to(device)
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
lengths = attention_mask.sum(dim=1)
with torch.no_grad():
out = model(
input_ids=input_ids,
attention_mask=attention_mask,
lengths=lengths,
labels=None,
ontology_masks=None,
decode=True,
)
# Print decoded predictions
for i, (spans_i, labels_i, scores_i) in enumerate(out["decoded"]):
print(f"\nDoc {i}:")
        for (st, ed), lab, sc in zip(spans_i.tolist(), labels_i.tolist(), scores_i.tolist()):
            # span indices refer to wordpiece positions; map them back through the
            # tokenizer's offset mapping to recover surface text if needed
            print(f"  span=({st},{ed}) label={lab} score={sc:.3f}")