mtopic.tl.sMTM_GPU
- class mtopic.tl.sMTM_GPU(mdata, n_topics=20, radius=0.05, n_neighbors=None, seed=2291, spatial_key='coords', verbose=True)
GPU-accelerated Spatial Multimodal Topic Model.
This class implements a CUDA-accelerated version of the Spatial Multimodal Topic Model (sMTM) for analyzing single-cell spatial data across multiple modalities. The model captures spatial relationships by constructing a spatial neighborhood graph and uses Variational Inference (VI) to identify spatially-aware topics. The model is mathematically equivalent to sMTM but executes the E-step and M-step updates on the GPU, using sparse CSR tensors and batched neighborhood computations for substantial speedups on large datasets.
- Parameters:
mdata (muon.MuData) – A MuData object containing multimodal single-cell spatial data, with spatial coordinates stored in the obsm attribute under the key given by spatial_key.
n_topics (int, optional) – Number of topics to infer. Each topic represents a distinct spatial pattern across features and modalities. Default is 20.
radius (float, optional) – Radius for constructing a spatial neighborhood graph. Used if n_neighbors is None. Default is 0.05.
n_neighbors (int, optional) – Number of neighbors to consider when constructing the spatial neighborhood graph. Overrides radius if set. Default is None.
seed (int, optional) – Random seed for reproducibility. Ensures consistent initialization and results. Default is 2291.
spatial_key (str, optional) – Key in the obsm attribute of MuData specifying spatial coordinates. Default is ‘coords’.
verbose (bool, optional) – If True, displays a progress bar during training. Default is True.
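The two neighborhood modes produce differently shaped results. With n_neighbors, every observation gets the same number of neighbors. With radius, neighborhood sizes vary, so the arrays must be padded to a rectangular shape. The following is a minimal NumPy sketch of this behavior, not the sMTM_GPU implementation. Several details are assumptions: the per-axis min-max scaling to [0, 1], excluding each point from its own neighborhood, the padding values (-1 for indices, inf for distances), and the function name build_neighborhood. The brute-force O(n²) distance matrix is only suitable for small inputs.

```python
import numpy as np


def build_neighborhood(coords, radius=0.05, n_neighbors=None, pad_index=-1):
    """Hypothetical sketch: scale coordinates, then build a kNN or radius graph.

    Returns (coords_scaled, dist, neigh). In radius mode, rows are padded to the
    largest neighborhood size with pad_index (indices) and np.inf (distances).
    Brute-force pairwise distances: O(n^2) memory, for illustration only.
    """
    coords = np.asarray(coords, dtype=float)
    # Assumption: each axis is min-max scaled to [0, 1] independently.
    lo, hi = coords.min(axis=0), coords.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)
    scaled = (coords - lo) / span

    # Pairwise Euclidean distances; a point is not its own neighbor (assumption).
    d = np.linalg.norm(scaled[:, None, :] - scaled[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)

    if n_neighbors is not None:
        # kNN mode overrides radius: every row has exactly n_neighbors entries.
        neigh = np.argsort(d, axis=1)[:, :n_neighbors]
        dist = np.take_along_axis(d, neigh, axis=1)
        return scaled, dist, neigh

    # Radius mode: variable-size neighborhoods padded into rectangular arrays.
    rows = [np.flatnonzero(row <= radius) for row in d]
    width = max((len(r) for r in rows), default=0)
    neigh = np.full((len(rows), width), pad_index, dtype=int)
    dist = np.full((len(rows), width), np.inf)
    for i, r in enumerate(rows):
        neigh[i, :len(r)] = r
        dist[i, :len(r)] = d[i, r]
    return scaled, dist, neigh
```

For example, `build_neighborhood(mdata.obsm["coords"], radius=0.05)` would mirror the default parameters. This assumes the coordinates are stored as a two-column numeric array.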
- Variables:
n_topics (int) – Number of topics initialized in the model.
radius (float) – Radius used for spatial neighborhood graph construction.
seed (int) – Random seed used for initializing the model.
rng (numpy.random.Generator) – Random number generator initialized with the seed.
device (str) – Compute device used by the model (always "cuda").
spatial_key (str) – Key for accessing spatial coordinates in MuData.
modalities (list) – List of modalities in the dataset.
features (dict) – Dictionary of feature names for each modality.
barcodes (list) – List of barcodes corresponding to the samples.
n_obs (int) – Number of samples (observations) in the dataset.
n_var (dict) – Dictionary with the number of features per modality.
coords_scaled (numpy.ndarray) – Spatial coordinates normalized to [0, 1].
coords (torch.Tensor) – Scaled spatial coordinates as a CUDA tensor.
neighborhood_dist (numpy.ndarray) – Distances between each sample and its neighbors. In radius mode, neighborhood sizes vary, so rows are padded to a common width.
neighborhood_graph (numpy.ndarray) – Indices of each sample's neighbors, padded in the same way as neighborhood_dist in radius mode.
dist (torch.Tensor) – Neighbor distances as a CUDA tensor.
neigh (torch.Tensor) – Neighbor indices as a CUDA tensor.
gamma (torch.Tensor or numpy.ndarray) – Variational parameters for the per-observation topic distributions.
lambda (dict) – Variational parameters for the topics, stored per modality.
exp_E_log_beta (dict) – Exponentiated expected log topic distributions, exp(E[log β]), stored per modality.
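The attribute names suggest standard variational LDA, in which each topic's distribution over the features of modality m has a Dirichlet variational posterior with parameter λ. This is an assumption, not something the library's code confirms. Under that assumption, each entry of exp_E_log_beta would be:

```latex
\exp\big(\mathbb{E}_q[\log \beta^{(m)}_{kw}]\big)
  = \exp\Big(\psi\big(\lambda^{(m)}_{kw}\big) - \psi\big(\textstyle\sum_{w'} \lambda^{(m)}_{kw'}\big)\Big)
```

Here ψ is the digamma function, k indexes topics, and w indexes the features of modality m.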
- Methods:
- VI(n_iter=20, max_iter_d=100)
- Perform Variational Inference (VI) to fit the model to the data.
All observations are processed in a single GPU batch each iteration; padding in the neighborhood arrays is handled internally so that neighborhoods of varying size (radius mode) contribute correctly to the spatial similarity term.
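Handling padding correctly matters. A naive average over the padded width would divide by the largest neighborhood size rather than by each observation's own. The following is a minimal NumPy sketch of masked neighbor averaging, not the GPU implementation. It assumes padding index -1 and a per-observation topic-proportion matrix theta, both hypothetical. The actual form of sMTM's spatial similarity term is not documented here.

```python
import numpy as np


def masked_neighbor_mean(theta, neigh, pad_index=-1):
    """Average each observation's neighbor rows of theta, ignoring padded slots.

    theta: (n_obs, n_topics) per-observation topic proportions (hypothetical input).
    neigh: (n_obs, width) neighbor indices padded with pad_index (assumed -1).
    Observations with no real neighbors get a row of zeros.
    """
    theta = np.asarray(theta, dtype=float)
    neigh = np.asarray(neigh)
    mask = neigh != pad_index                  # True where a real neighbor exists
    safe = np.where(mask, neigh, 0)            # swap padding for a valid index
    gathered = theta[safe] * mask[..., None]   # (n_obs, width, n_topics), padding zeroed
    counts = mask.sum(axis=1, keepdims=True)   # true neighborhood sizes
    return gathered.sum(axis=1) / np.maximum(counts, 1)
```

The same gather-mask-divide pattern carries over to PyTorch tensors on CUDA. It keeps the whole batch in one vectorized operation instead of looping over ragged neighborhoods.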
- Parameters:
n_iter (int, optional) – Number of iterations for the VI algorithm. Default is 20.
max_iter_d (int, optional) – Maximum number of iterations for the E-step in each VI update. Default is 100.
- Returns:
None
- Return type:
None
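To make the roles of n_iter and max_iter_d concrete, each outer VI iteration can be viewed as an E-step of at most max_iter_d inner updates over all observations at once, followed by an M-step. The following sketch uses standard batch variational LDA updates, extended to several modalities that share per-observation topic proportions. It illustrates the structure only and is not sMTM's actual update. It omits the spatial term entirely and runs on the CPU, with scipy.sparse CSR matrices standing in for the CUDA sparse CSR tensors. The hyperparameters alpha, eta and tol, and all function names, are hypothetical. count_ratio evaluates the normalizer only at nonzero counts, which is one reason a sparse layout can pay off.

```python
import numpy as np
import scipy.sparse as sp
from scipy.special import digamma


def exp_e_log_dirichlet(param):
    """exp(E[log X]) for X ~ Dirichlet(row), applied to every row of param."""
    return np.exp(digamma(param) - digamma(param.sum(axis=1, keepdims=True)))


def count_ratio(x, theta, beta):
    """CSR matrix holding n_dw / sum_k theta_dk * beta_kw at the nonzeros of x only."""
    rows = np.repeat(np.arange(x.shape[0]), np.diff(x.indptr))  # row of each nonzero
    phinorm = np.einsum("ik,ki->i", theta[rows], beta[:, x.indices]) + 1e-100
    return sp.csr_matrix((x.data / phinorm, x.indices, x.indptr), shape=x.shape)


def e_step(counts, beta, n_topics, alpha=0.1, max_iter_d=100, tol=1e-3, seed=2291):
    """Batched E-step: update gamma (n_obs, n_topics) for all observations at once.

    counts: dict modality -> CSR counts (n_obs, n_features_m)
    beta:   dict modality -> exp(E[log beta]) of shape (n_topics, n_features_m)
    """
    rng = np.random.default_rng(seed)
    n_obs = next(iter(counts.values())).shape[0]
    gamma = rng.gamma(100.0, 0.01, size=(n_obs, n_topics))
    for _ in range(max_iter_d):
        theta = exp_e_log_dirichlet(gamma)
        # All modalities share the same topic proportions, so their evidence adds up.
        update = sum(count_ratio(x, theta, beta[m]) @ beta[m].T for m, x in counts.items())
        new_gamma = alpha + theta * update
        done = np.mean(np.abs(new_gamma - gamma)) < tol
        gamma = new_gamma
        if done:
            break
    return gamma


def m_step(counts, gamma, beta, eta=0.01):
    """M-step: lambda_m = eta + sum_d n_dw * phi_dwk for every modality m."""
    theta = exp_e_log_dirichlet(gamma)
    return {m: eta + beta[m] * (count_ratio(x, theta, beta[m]).T @ theta).T
            for m, x in counts.items()}
```

A hypothetical outer loop would repeat three steps n_iter times:

1. `gamma = e_step(counts, beta, n_topics, max_iter_d=max_iter_d)`
2. `lam = m_step(counts, gamma, beta)`
3. `beta = {m: exp_e_log_dirichlet(l) for m, l in lam.items()}`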
- Example:
    model = mtopic.tl.sMTM_GPU(mdata, n_topics=20, radius=0.05)
    model.VI(n_iter=20)
- Example:
    import mtopic

    # Load spatial multimodal single-cell data
    mdata = mtopic.read.h5mu("path/to/file.h5mu")

    # Initialize and train the model
    model = mtopic.tl.sMTM_GPU(mdata, n_topics=20, radius=0.05)
    model.VI(n_iter=20)
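Because gamma may be either a torch.Tensor or a numpy.ndarray, downstream analysis may first need to bring it into NumPy. The following is a minimal sketch under two assumptions: gamma has shape (n_obs, n_topics), and each row holds the Dirichlet parameters of one observation. Normalizing each row then gives the posterior-mean topic proportions. The helper name topic_proportions is hypothetical.

```python
import numpy as np


def topic_proportions(gamma):
    """Row-normalize gamma into per-observation topic proportions.

    Accepts a torch.Tensor (possibly on CUDA) or a numpy.ndarray; tensors are
    detected by their .detach method and copied to host memory first.
    """
    if hasattr(gamma, "detach"):          # torch.Tensor: detach and move to CPU
        gamma = gamma.detach().cpu().numpy()
    gamma = np.asarray(gamma, dtype=float)
    # The mean of Dirichlet(gamma_d) is gamma_d / sum_k gamma_dk.
    return gamma / gamma.sum(axis=1, keepdims=True)
```

For example, `props = topic_proportions(model.gamma)` could be called after `model.VI(n_iter=20)`, assuming gamma is populated once training finishes.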
- __init__(mdata, n_topics=20, radius=0.05, n_neighbors=None, seed=2291, spatial_key='coords', verbose=True)
Methods
- VI([n_iter, max_iter_d])
- __init__(mdata[, n_topics, radius, ...])