TL;DR

Vision-language models (such as CLIP) are frequently used as zero-shot classifiers: given a set of text prompts, one scores each prompt's embedding against the image embedding and picks the best match. Prompt engineering can improve this zero-shot classification significantly, but it is time-consuming. The CoOp method proposes learning the prompt instead: trained with a single sample per class (i.e. one shot), it reaches performance comparable to human-crafted prompts, and with 16 samples per class it outperforms human-crafted prompts by about +15%.
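For reference, here is a minimal sketch of CLIP-style zero-shot classification using OpenAI's clip package; the image path and class names are placeholders:

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a pre-trained CLIP model and its image preprocessing pipeline.
model, preprocess = clip.load("ViT-B/32", device=device)

# Hand-crafted prompts: one per candidate class (placeholder class names).
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path

with torch.no_grad():
    # Similarity between the image embedding and each prompt embedding.
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1)  # probabilities over the classes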

[Figure: teaser]

Method

The authors propose the "Context Optimization" (CoOp) method: replace the prompt's context words with learnable vectors, initialized either with random values or with the embeddings of a hand-crafted prompt. Two implementations are provided to handle tasks of different natures:

  • Unified context shares the same context vectors across all classes and works well on most categories. At inference, the text encoder receives the single shared context (paired with each class-name token) and yields a prediction over the trained classes.
  • Class-specific context learns a separate set of context tokens for each class and is found to be more suitable for some fine-grained categories. At inference, the text encoder receives "n_cls" sets of context vectors, one per class.

During training, the prediction error is minimized using the cross-entropy loss with respect to the learnable context vectors, while the pre-trained CLIP parameters are kept fixed; gradients are back-propagated all the way through the text encoder, as in the sketch below.
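A minimal sketch of one training step, not the authors' exact code: clip_model, images, labels, and the encode_prompts helper (standing in for the repo's PromptLearner and TextEncoder modules) are assumed to exist:

import torch
import torch.nn.functional as F

# Freeze the pre-trained CLIP weights; only the context vectors learn.
for p in clip_model.parameters():
    p.requires_grad_(False)
ctx_vectors.requires_grad_(True)

optimizer = torch.optim.SGD([ctx_vectors], lr=0.002)

# Image features from frozen CLIP; text features from the learnable
# prompts (encode_prompts is a hypothetical helper wrapping the text
# encoder, so gradients flow back through it to ctx_vectors).
image_features = clip_model.encode_image(images)   # (batch, dim)
text_features = encode_prompts(ctx_vectors)        # (n_cls, dim)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logits = clip_model.logit_scale.exp() * image_features @ text_features.t()
loss = F.cross_entropy(logits, labels)             # labels: (batch,)

optimizer.zero_grad()
loss.backward()   # gradients reach only ctx_vectors; CLIP stays fixed
optimizer.step()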

[Figure: method]

Unified context vs. class-specific context, from the paper's GitHub:

import torch
import torch.nn as nn

# Random initialization of the learnable context vectors.
# n_cls: number of classes, n_ctx: number of context tokens,
# ctx_dim: dimension of the text encoder's token embeddings.
if cfg.TRAINER.COOP.CSC:
    print("Initializing class-specific contexts")
    # One set of context vectors per class: (n_cls, n_ctx, ctx_dim)
    ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
else:
    print("Initializing a generic context")
    # A single context shared by all classes: (n_ctx, ctx_dim)
    ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
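Downstream, these context vectors are concatenated with each class name's token embeddings to form full prompts. A simplified sketch, loosely following the repo's PromptLearner (prefix holds the start-of-text embedding, suffix the class-name and end-of-text embeddings):

# Simplified prompt assembly; prefix/suffix are precomputed embeddings.
if ctx_vectors.dim() == 2:
    # Unified context: replicate the shared context for every class.
    ctx = ctx_vectors.unsqueeze(0).expand(n_cls, -1, -1)
else:
    # Class-specific context already has shape (n_cls, n_ctx, ctx_dim).
    ctx = ctx_vectors
# Full prompts, one sequence per class, fed to the text encoder.
prompts = torch.cat([prefix, ctx, suffix], dim=1)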

Limitations

  • It’s not zero-shot per se: it still requires training the context vectors.
  • Answering “does this class appear in the image at all?” is not straightforward with this method: it is trained to classify an image into one of a closed set of classes and cannot indicate when none of the classes are present in the image, as the toy example below illustrates.
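A toy illustration of the closed-set issue (the numbers are made up): softmax spreads all probability mass over the known classes, so one class always "wins" even when the image matches none of them.

import torch

# Nearly uniform, low similarity scores against every known class.
logits = torch.tensor([0.10, 0.20, 0.15])
probs = logits.softmax(dim=-1)
print(probs, probs.sum())  # probabilities still sum to 1 -> some class "wins"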

Resources

arXiv: “Learning to Prompt for Vision-Language Models”, published in the International Journal of Computer Vision (IJCV)

GitHub: KaiyangZhou/CoOp