Pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.

<img src="https://github.com/jmschrei/pomegranate/blob/master/docs/logo/pomegranate-logo.png" width=300>

Note: pomegranate v1.0.0 is a ground-up rewrite of pomegranate using PyTorch as the computational backend instead of Cython. Although the same functionality is supported, the API is significantly different. Please see the tutorials and examples folders for help rewriting your code.

ReadTheDocs | Tutorials | Examples

pomegranate is a library for probabilistic modeling defined by its modular implementation and treatment of all models as the probability distributions they are. The modular implementation allows one to easily drop normal distributions into a mixture model to create a Gaussian mixture model just as easily as dropping a gamma and a Poisson distribution into a mixture model to create a heterogeneous mixture. But that's not all! Because each model is treated as a probability distribution, Bayesian networks can be dropped into a mixture just as easily as a normal distribution, and hidden Markov models can be dropped into Bayes classifiers to make a classifier over sequences. Together, these two design choices enable a flexibility not seen in any other probabilistic modeling package.
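To make the design idea concrete, here is a minimal pure-Python sketch of a mixture that accepts any component exposing `log_probability` and a weighted `fit`. Every class and name in it is a hypothetical stand-in written for illustration. It is not pomegranate's implementation or API, and it works on 1-D lists rather than tensors.

```python
import math

EPS = 1e-12  # floor that keeps logs and divisions finite


def weighted_mean(xs, weights):
    return sum(w * x for x, w in zip(xs, weights)) / max(sum(weights), EPS)


class Exponential:
    """Toy 1-D exponential distribution parameterised by its mean."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def log_probability(self, x):
        return -math.log(self.scale) - x / self.scale

    def fit(self, xs, weights):
        self.scale = max(weighted_mean(xs, weights), EPS)


class Poisson:
    """Toy Poisson distribution over non-negative counts."""

    def __init__(self, rate=1.0):
        self.rate = rate

    def log_probability(self, x):
        return x * math.log(self.rate) - self.rate - math.lgamma(x + 1)

    def fit(self, xs, weights):
        self.rate = max(weighted_mean(xs, weights), EPS)


class Mixture:
    """Mixture over any components exposing log_probability and weighted fit."""

    def __init__(self, distributions):
        self.distributions = distributions
        self.priors = [1.0 / len(distributions)] * len(distributions)

    def _joint(self, x):
        # log p(component k) + log p(x | component k), one entry per component
        return [math.log(max(p, EPS)) + d.log_probability(x)
                for p, d in zip(self.priors, self.distributions)]

    def log_probability(self, x):
        terms = self._joint(x)
        top = max(terms)
        return top + math.log(sum(math.exp(t - top) for t in terms))

    def fit(self, xs, weights=None, n_iter=20):
        weights = [1.0] * len(xs) if weights is None else weights
        for _ in range(n_iter):
            # E-step: responsibility of each component for each point
            resp = []
            for x in xs:
                norm = self.log_probability(x)
                resp.append([math.exp(t - norm) for t in self._joint(x)])
            # M-step: each component refits itself on responsibility-weighted data
            for k, d in enumerate(self.distributions):
                wk = [w * r[k] for w, r in zip(weights, resp)]
                d.fit(xs, wk)
                self.priors[k] = sum(wk) / max(sum(weights), EPS)
        return self


counts = [0, 1, 1, 2, 3, 3, 4, 5, 8, 12]

# A heterogeneous mixture: two different component types side by side.
mix = Mixture([Exponential(1.0), Poisson(3.0)]).fit(counts)

# A mixture is itself a distribution, so it can be a component of another mixture.
nested = Mixture([Mixture([Exponential(1.0), Exponential(5.0)]), Poisson(3.0)]).fit(counts)
```

Because `Mixture` exposes the same two methods as its components, nesting needs no special handling. That is the property the paragraph above describes.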

Recently, pomegranate (v1.0.0) was rewritten from the ground up using PyTorch to replace the outdated Cython backend. This rewrite gave me an opportunity to fix many bad design choices that I made as a bb software engineer. Unfortunately, many of these changes are not backwards compatible and will disrupt workflows. On the flip side, these changes have significantly sped up most methods, improved and simplified the code, fixed many issues raised by the community over the years, and made it significantly easier to contribute. I've written more below, but you're likely here now because your code is broken and this is the tl;dr.

Special shout-out to NumFOCUS for supporting this work with a special development grant.

Installation

pip install pomegranate

If you need the last Cython release before the rewrite, use pip install pomegranate==0.14.8. You may need to manually install a version of Cython before v3.

Why a Rewrite?

This rewrite was motivated by four main reasons:

  • <b>Speed</b>: Native PyTorch is usually significantly faster than the hand-tuned Cython code that I wrote.
  • <b>Features</b>: PyTorch has many features, such as serialization, mixed precision, and GPU support, that can now be directly used in pomegranate without additional work on my end.
  • <b>Community Contribution</b>: A challenge that many people faced when using pomegranate was that they could not modify or extend it because they did not know Cython. Even if they did know Cython, coding in it is a pain that I felt each time I tried adding a new feature or fixing a bug or releasing a new version. Using PyTorch as the backend significantly reduces the amount of effort needed to add in new features.
  • <b>Interoperability</b>: Libraries like PyTorch offer an invaluable opportunity to not just utilize their computational backends but to better integrate into existing resources and communities. This rewrite will make it easier for people to integrate probabilistic models with neural networks as losses, constraints, and structural regularizations, as well as with other projects built on PyTorch.

High-level Changes

  1. General
  • The entire codebase has been rewritten in PyTorch, and all models are instances of torch.nn.Module
  • The codebase is checked by a comprehensive suite of >800 unit tests calling assert statements several thousand times, much more than previous versions
  • Installation issues are now likely to come from PyTorch, for which there are countless resources to help out
  2. Features
  • All models now have GPU support
  • All models now have support for half/mixed precision
  • Serialization is now handled by PyTorch, yielding more compact and efficient I/O
  • Missing values are now supported through torch.masked.MaskedTensor objects
  • Prior probabilities can now be passed to all relevant models and methods, enabling more comprehensive and flexible semi-supervised learning than before
  3. Models
  • All distributions are now multivariate by default and treat each feature independently (except Normal)
  • "Distribution" has been removed from names so that, for example, NormalDistribution is now Normal
  • FactorGraph is now supported as a first-class citizen, with all the prediction and training methods
  • Hidden Markov models have been split into DenseHMM and SparseHMM models, which differ in how the transition matrix is encoded, with DenseHMM objects being significantly faster on truly dense graphs
  4. Differences
  • NaiveBayes has been permanently removed as it is redundant with BayesClassifier
  • MarkovNetwork has not yet been implemented
  • Constraint graphs and constrained structure learning for Bayesian networks have not yet been implemented
  • Silent states for hidden Markov models have not yet been implemented
  • Viterbi for hidden Markov models has not yet been implemented
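The dense versus sparse distinction for hidden Markov models can be illustrated with a small pure-Python sketch. It propagates a vector of state probabilities through one transition step using two encodings: a full matrix and a list of non-zero edges. The function names and the three-state example are assumptions made for illustration, emission probabilities are left out, and this is not how DenseHMM or SparseHMM are implemented internally.

```python
def step_dense(alpha, transitions):
    """One transition step using a full n x n matrix (touches all n*n entries)."""
    n = len(alpha)
    return [sum(alpha[i] * transitions[i][j] for i in range(n)) for j in range(n)]


def step_sparse(alpha, edges):
    """The same step using only the non-zero (source, target, probability) edges."""
    out = [0.0] * len(alpha)
    for i, j, p in edges:
        out[j] += alpha[i] * p
    return out


# A left-to-right chain: most entries of the matrix are zero.
dense = [
    [0.9, 0.1, 0.0],
    [0.0, 0.8, 0.2],
    [0.0, 0.0, 1.0],
]
edges = [(i, j, p) for i, row in enumerate(dense) for j, p in enumerate(row) if p > 0.0]

alpha = [0.5, 0.3, 0.2]
dense_next = step_dense(alpha, dense)
sparse_next = step_sparse(alpha, edges)
```

Both encodings give the same result. The edge list does work proportional to the number of non-zero transitions, while the matrix form does work proportional to the square of the number of states but maps onto a single matrix multiplication, which is why a dense encoding tends to win when most transitions are actually present.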

Speed

Most models and methods in pomegranate v1.0.0 are faster than their counterparts in earlier versions. The speedup generally scales with complexity: one sees only small speedups for simple distributions on small data sets but much larger speedups for more complex models on big data sets, e.g. hidden Markov model training or Bayesian network inference. The notable exception for now is that Bayesian network structure learning, other than Chow-Liu tree building, is still incomplete and not much faster. In the examples below, torchegranate refers to the temporary repository used to develop pomegranate v1.0.0 and pomegranate refers to pomegranate v0.14.8.

K-Means

Who knows what's happening here? Wild.

[Figure: K-Means speed comparison, torchegranate vs. pomegranate]

Hidden Markov Models

Dense transition matrix (CPU)

[Figure: hidden Markov model speed comparison, dense transition matrix on CPU]

Sparse transition matrix (CPU)

[Figure: hidden Markov model speed comparison, sparse transition matrix on CPU]

Training a 125 node model with a dense transition matrix

[Figure: speed comparison for training a 125-node model with a dense transition matrix]

Bayesian Networks

[Figures: Bayesian network speed comparisons]

Features

Note: Please see the tutorials folder for code examples.

Switching from a Cython backend to a PyTorch backend has enabled or expanded a large number of features. Because the rewrite is a thin wrapper over PyTorch, as new features get released for PyTorch they can be applied to pomegranate models without the need for a new release from me.

GPU Support

All distributions and methods in pomegranate now have GPU support. Because each distribution is a torch.nn.Module object, the use is identical to other code written in PyTorch. This means that both the model and the data have to be moved to the GPU by the user. For instance:

>>> import torch
>>> from pomegranate.distributions import Exponential
>>>
>>> X = torch.exp(torch.randn(50, 4))

# Will execute on the CPU
>>> d = Exponential().fit(X)
>>> d.scales
Parameter containing:
tensor([1.8627, 1.3132, 1.7187, 1.4957])

# Will execute on a GPU
>>> d = Exponential().cuda().fit(X.cuda())
>>> d.scales
Parameter containing:
tensor([1.8627, 1.3132, 1.7187, 1.4957], device='cuda:0')

Likewise, all models are distributions, and so can be used on the GPU similarly. When a model is moved to the GPU, all of the models associated with it (e.g. distributions) are also moved to the GPU.

>>> from pomegranate.gmm import GeneralMixtureModel
>>>
>>> X = torch.exp(torch.randn(50, 4)).cuda()
>>> model = GeneralMixtureModel([Exponential(), Exponential()], verbose=True).cuda()
>>> model.fit(X)
[1] Improvement: 1.26068115234375, Time: 0.001134s
[2] Improvement: 0.168121337890625, Time: 0.001097s
[3] Improvement: 0.037841796875, Time: 0.001095s
>>> model.distributions[0].scales
Parameter containing:
tensor([0.9141, 1.0835, 2.7503, 2.2475], device='cuda:0')
>>> model.distributions[1].scales
Parameter containing:
tensor([1.9902, 2.3871, 0.8984, 1.2215], device='cuda:0')
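The examples above assume a CUDA device is present. The following sketch shows a device-agnostic version of the same pattern, and checks that moving a parent module also moves the parameters of its registered child modules. `ToyExponential` and `ToyMixture` are hypothetical stand-ins built only from `torch.nn` primitives so that the snippet runs without pomegranate installed. Real pomegranate models should behave the same way only insofar as they register their components as submodules.

```python
import torch


class ToyExponential(torch.nn.Module):
    """Hypothetical stand-in for a distribution: one scale per feature."""

    def __init__(self, n_features):
        super().__init__()
        self.scales = torch.nn.Parameter(torch.ones(n_features), requires_grad=False)


class ToyMixture(torch.nn.Module):
    """Hypothetical stand-in for a model that owns several distributions."""

    def __init__(self, distributions):
        super().__init__()
        # ModuleList registers the children so .to() / .cuda() reach them too.
        self.distributions = torch.nn.ModuleList(distributions)
        self.priors = torch.nn.Parameter(
            torch.ones(len(distributions)) / len(distributions), requires_grad=False
        )


# Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ToyMixture([ToyExponential(4), ToyExponential(4)]).to(device)
X = torch.exp(torch.randn(50, 4)).to(device)

# Every parameter, including those of the child modules, now lives on `device`.
param_devices = {p.device.type for p in model.parameters()}
```

Keeping the device in a single variable and applying it to both the model and the data avoids the most common failure mode, which is a model on the GPU being handed a tensor that is still on the CPU.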

Mixed Precision

pomegranate models can, in theory, operate in the same mixed or low-precision regimes as other PyTorch modules. However, because pomegranate uses more complex operations than most neural networks, this sometimes does not work or help in practice because these operations have not been optimized or implemented in the low-precision regime. So, hopefully this feature will become more useful over time.

>>> X = torch.randn(100, 4)
>>> d = Normal(covariance_type='diag')
>>>
>>> with torch.autocast('cuda', dtype=torch.bfloat16):
...     d.fit(X)

Serialization

pomegranate distributions are all instances of torch.nn.Module and so serialization is the same as any other PyTorch model.

Saving:

>>> X = torch.exp(torch.randn(50, 4)).cuda()
>>> model = GeneralMixtureModel([Exponential(), Exponential()], verbose=True)
>>> model.cuda()
>>> model.fit(X)
>>> torch.save(model, "test.torch")

Loading:

>>> model = torch.load("test.torch")
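One caveat when loading: starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses to unpickle a whole module object such as the one saved above. The sketch below shows the round trip with `weights_only=False` passed explicitly. It uses `torch.nn.Linear` as a stand-in for a fitted pomegranate model, an assumption made only so that the snippet runs without pomegranate installed.

```python
import os
import tempfile

import torch

# Stand-in module; a fitted pomegranate model would take its place.
model = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "test.torch")
    torch.save(model, path)
    # weights_only=False permits unpickling a full module object.
    # Unpickling can execute arbitrary code, so only load files you trust.
    restored = torch.load(path, weights_only=False)

# The restored module carries the same parameter values as the original.
same_values = all(
    torch.equal(a, b) for a, b in zip(model.parameters(), restored.parameters())
)
```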

torch.compile

Note: torch.compile is under active development
