
Support SPLADE style sparse embeddings #129

@jordanparker6


Feature request

It would be great to have a single server that calculates both the dense and sparse embeddings for hybrid search.

Adding a sparse endpoint that calculates the sparse embeddings would make this possible; a rough client-side sketch of what that could look like is below.
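
For illustration only, a client call to a hypothetical /embed_sparse route might look roughly like this. The endpoint name and the response shape (parallel lists of token indices and weights per input) are assumptions for this proposal, not an existing TEI API:

import requests

# Assumed request: same shape as the existing /embed route,
# but returning sparse vectors instead of dense ones.
response = requests.post(
    "http://localhost:8080/embed_sparse",  # hypothetical route, not an existing TEI endpoint
    json={"inputs": ["what is hybrid search?", "SPLADE produces sparse vectors"]},
)
response.raise_for_status()

# Assumed response: one sparse vector per input text, given as parallel
# lists of vocabulary indices and their non-zero weights.
for sparse in response.json():
    print(sparse["indices"][:5], sparse["values"][:5])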

Motivation

I want to be able to use the same TEI server for both my dense and sparse embedding calculations when doing hybrid search.

Your contribution

I am happy to help by working on a PR.

The following code from llama_index may be of use for computing the sparse vectors and returning them in a format suitable for vector databases like Qdrant.

from typing import Callable, List, Tuple

# Type alias matching llama_index's SparseEncoderCallable
SparseEncoderCallable = Callable[[List[str]], Tuple[List[List[int]], List[List[float]]]]


def default_sparse_encoder(model_id: str) -> SparseEncoderCallable:
    try:
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer
    except ImportError:
        raise ImportError(
            "Could not import transformers library. "
            'Please install transformers with `pip install "transformers[torch]"`'
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    if torch.cuda.is_available():
        model = model.to("cuda")

    def compute_vectors(texts: List[str]) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Computes vectors from logits and attention mask using ReLU, log, and max operations.
        """
        # TODO: compute sparse vectors in batches if max length is exceeded
        tokens = tokenizer(
            texts, truncation=True, padding=True, max_length=512, return_tensors="pt"
        )
        if torch.cuda.is_available():
            tokens = tokens.to("cuda")

        output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        # SPLADE-style aggregation: log-saturated ReLU of the MLM logits,
        # masked by the attention mask, then max-pooled over the sequence dimension
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        tvecs, _ = torch.max(weighted_log, dim=1)

        # extract the vectors that are non-zero and their indices
        indices = []
        vecs = []
        for batch in tvecs:
            indices.append(batch.nonzero(as_tuple=True)[0].tolist())
            vecs.append(batch[indices[-1]].tolist())

        return indices, vecs

    return compute_vectors
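
As a usage sketch, the encoder's output can be converted into Qdrant sparse vectors as shown below. The SPLADE checkpoint id and the qdrant-client SparseVector model are illustrative assumptions, not part of the snippet above:

# Usage sketch: the model id and Qdrant types below are assumptions for illustration.
from qdrant_client import models

encode = default_sparse_encoder("naver/splade-cocondenser-ensembledistil")
indices, values = encode(["what is hybrid search?"])

# Convert the parallel index/value lists into Qdrant sparse vectors
sparse_vectors = [
    models.SparseVector(indices=idx, values=val)
    for idx, val in zip(indices, values)
]
print(sparse_vectors[0])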
