TorchCRF
An Implementation of CRF (Conditional Random Fields) in PyTorch 1.0
Requirements
- python3 (>=3.6)
- PyTorch (>=1.0)
Installation
$ pip install TorchCRF
Usage
>>> import torch
>>> from TorchCRF import CRF
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> batch_size = 2
>>> sequence_size = 3
>>> num_labels = 5
>>> mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]]).to(device) # (batch_size, sequence_size)
>>> labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device) # (batch_size, sequence_size)
>>> hidden = torch.randn((batch_size, sequence_size, num_labels), device=device, requires_grad=True) # (batch_size, sequence_size, num_labels)
>>> crf = CRF(num_labels)
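The mask marks which positions of each padded sequence hold real tokens: the second sequence in the example has length 2, so its last position is 0. A minimal plain-Python sketch of building such a mask from sequence lengths follows; lengths_to_mask is a hypothetical helper for illustration, not part of TorchCRF.

```python
def lengths_to_mask(lengths, max_len):
    """Hypothetical helper (not TorchCRF API): 1 marks a real token, 0 marks padding."""
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]


# Two sequences of lengths 3 and 2, padded to length 3 (the mask used above).
mask = lengths_to_mask([3, 2], max_len=3)
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```

The nested list can then be wrapped as torch.ByteTensor(mask), as in the usage example.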
Computing the log-likelihood (via forward)
>>> crf.forward(hidden, labels, mask)
tensor([-7.6204, -3.6124], device='cuda:0', grad_fn=<ThSubBackward>)
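The call above returns one log-likelihood per sequence in the batch. For a linear-chain CRF this is the score of the gold label path minus the log-partition function (the log-sum-exp of the scores of all label paths), which the forward algorithm computes in time linear in the sequence length. The sketch below is a from-scratch illustration in plain Python, not TorchCRF's implementation; it assumes the scores consist of per-position emissions (one sequence of hidden), a label-to-label transition matrix, and start/end scores, and it handles masking by scoring only the unmasked prefix.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def crf_log_likelihood(emissions, labels, trans, start, end):
    """log p(labels | emissions) for one unpadded sequence under a linear-chain CRF.

    emissions: T lists of num_labels scores (one sequence of the hidden tensor)
    labels:    T gold label indices
    trans:     trans[i][j] is the assumed score of moving from label i to label j
    start/end: assumed per-label scores for the first and last position
    """
    num_labels = len(start)
    # Score of the gold path: start + emissions + transitions + end.
    score = start[labels[0]] + emissions[0][labels[0]]
    for t in range(1, len(labels)):
        score += trans[labels[t - 1]][labels[t]] + emissions[t][labels[t]]
    score += end[labels[-1]]

    # Forward algorithm: alpha[j] is the log-sum-exp of all prefix scores ending in label j.
    alpha = [start[j] + emissions[0][j] for j in range(num_labels)]
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + trans[i][j] for i in range(num_labels)]) + emissions[t][j]
            for j in range(num_labels)
        ]
    log_partition = logsumexp([alpha[j] + end[j] for j in range(num_labels)])
    return score - log_partition


# Toy example with 2 labels; the mask [1, 1, 0] keeps only the first two positions.
emissions = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
trans = [[0.2, -0.1], [0.0, 0.3]]
start, end = [0.1, 0.0], [0.0, 0.1]
mask = [1, 1, 0]
length = sum(mask)
print(crf_log_likelihood(emissions[:length], [0, 1, 0][:length], trans, start, end))
```

Because the result is a log-probability it is always at most 0; for training, one would typically minimize its negation, summed or averaged over the batch.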
Decoding (predicting the label sequence of each input)
>>> crf.viterbi_decode(hidden, mask)
[[0, 2, 2], [4, 0]]
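Decoding returns, for each sequence, the highest-scoring label path over its unmasked positions only, which is why the second list above has two labels. The Viterbi dynamic program behind this can be sketched in plain Python; this is an illustration under an assumed parameterization (emissions, a transition matrix, start/end scores), not the library's code, and the function name mirrors viterbi_decode only for readability.

```python
def viterbi_decode(emissions, mask, trans, start, end):
    """Highest-scoring label path for each sequence in a batch, ignoring padding.

    emissions: batch of sequences, each T lists of num_labels scores
    mask:      batch of T 0/1 flags; padding is assumed to sit at the end and
               every sequence is assumed to have at least one real token
    trans:     trans[i][j] is the assumed score of moving from label i to label j
    start/end: assumed per-label scores for the first and last real position
    Returns one list of labels per sequence, as long as its unmasked part.
    """
    num_labels = len(start)
    best_paths = []
    for seq, seq_mask in zip(emissions, mask):
        length = sum(seq_mask)
        # delta[j]: best score of any prefix ending in label j.
        delta = [start[j] + seq[0][j] for j in range(num_labels)]
        backpointers = []
        for t in range(1, length):
            pointers, new_delta = [], []
            for j in range(num_labels):
                # Best previous label i for arriving at label j at step t.
                best_i = max(range(num_labels), key=lambda i: delta[i] + trans[i][j])
                pointers.append(best_i)
                new_delta.append(delta[best_i] + trans[best_i][j] + seq[t][j])
            delta = new_delta
            backpointers.append(pointers)
        # Choose the best final label (end scores included), then walk back.
        last = max(range(num_labels), key=lambda j: delta[j] + end[j])
        path = [last]
        for pointers in reversed(backpointers):
            path.append(pointers[path[-1]])
        best_paths.append(path[::-1])
    return best_paths


# With zero transition/start/end scores the best path is the per-position argmax.
emissions = [
    [[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]],
    [[0.0, 2.0], [2.0, 0.0], [0.0, 0.0]],
]
mask = [[1, 1, 1], [1, 1, 0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(viterbi_decode(emissions, mask, zeros, [0.0, 0.0], [0.0, 0.0]))  # [[0, 1, 0], [1, 0]]
```

Keeping backpointers instead of whole paths makes each step cost O(num_labels^2), so decoding is linear in the sequence length rather than exponential.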
License
MIT
