Saturday, February 17, 2024

Quick Note: Contrastive Learning in DSSM

Contrastive learning is a way to set up the loss function. With a point-wise loss function, negative samples and positive samples are considered separately. With contrastive learning, the loss function (a pair-wise loss) is built from three items: two positively related items and one unrelated item. The cosine similarity of the two positively related items should be large, and the similarity of the unrelated pair should be small. The difference between the two similarity values drives the loss: the further apart the two values, the better. Since the optimization step reduces the loss, training pushes the model to fit the positive and negative samples at the same time.

To separate the positive and negative samples more cleanly, a desired margin m is usually defined. If cos(a, b_positive) exceeds cos(a, b_negative) by at least m, there is no loss; if the gap is less than m, the loss is how far short of m it falls.

So the loss function is written as 

    Loss(a, b_positive, b_negative) = max{0, cos(a, b_negative) - cos(a, b_positive) + m}
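A minimal PyTorch sketch of this pair-wise hinge loss (the tensor names and the margin value 0.2 are illustrative assumptions, not from the video):

    import torch
    import torch.nn.functional as F

    def pairwise_hinge_loss(a, b_pos, b_neg, m=0.2):
        # max(0, cos(a, b_negative) - cos(a, b_positive) + m), averaged over the batch
        cos_pos = F.cosine_similarity(a, b_pos, dim=-1)
        cos_neg = F.cosine_similarity(a, b_neg, dim=-1)
        return torch.clamp(cos_neg - cos_pos + m, min=0).mean()

    # Example with random embeddings: batch of 4, embedding dimension 8
    a = torch.randn(4, 8)
    loss = pairwise_hinge_loss(a, torch.randn(4, 8), torch.randn(4, 8))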

It is also possible to write the loss in logistic form, where sigma is a scaling hyperparameter:

    Loss(a, b_positive, b_negative) = log(1 + exp[ sigma * (cos(a, b_negative) - cos(a, b_positive)) ] )
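A sketch of the logistic variant, assuming a sigma of 10.0 (the value is not given in the video):

    import torch.nn.functional as F

    def pairwise_logistic_loss(a, b_pos, b_neg, sigma=10.0):
        # log(1 + exp(sigma * (cos(a, b_negative) - cos(a, b_positive)))), averaged over the batch
        cos_pos = F.cosine_similarity(a, b_pos, dim=-1)
        cos_neg = F.cosine_similarity(a, b_neg, dim=-1)
        return F.softplus(sigma * (cos_neg - cos_pos)).mean()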


List-wise loss follows a similar idea to the pair-wise loss, but considers more samples in the same loss function. Each training record consists of one positive pair (a, b) and multiple negative samples (b_neg_1, b_neg_2, etc.). Take the cosine similarity of each pair: cos(a, b), cos(a, b_neg_1), cos(a, b_neg_2), etc., put all the results through a Softmax, and let the expected output be (1, 0, 0, 0, ...). Train with the cross-entropy loss between the Softmax output and the expected output.
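A minimal sketch of this list-wise loss, assuming the negatives for each record are stacked into one tensor of shape (batch, k, dim); the scale factor is an illustrative assumption to keep the Softmax from being too flat on cosine values in [-1, 1]:

    import torch
    import torch.nn.functional as F

    def listwise_loss(a, b_pos, b_negs, scale=20.0):
        # Logits are [cos(a,b), cos(a,b_neg_1), ...]; the target class is index 0,
        # which corresponds to the expected distribution (1, 0, 0, ...).
        cos_pos = F.cosine_similarity(a, b_pos, dim=-1).unsqueeze(1)    # (batch, 1)
        cos_neg = F.cosine_similarity(a.unsqueeze(1), b_negs, dim=-1)   # (batch, k)
        logits = scale * torch.cat([cos_pos, cos_neg], dim=1)           # (batch, 1 + k)
        target = torch.zeros(logits.size(0), dtype=torch.long)          # positive is class 0
        return F.cross_entropy(logits, target)

    # Example: batch of 4, 5 negatives per record, embedding dimension 8
    loss = listwise_loss(torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 5, 8))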


Note: When using a point-wise loss function (that is, training on negative and positive samples separately), the ratio of positive samples to negative samples should be roughly 1:2 to 1:3.
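For reference, a tiny sketch of assembling point-wise training data at a 1:2 ratio; build_pointwise_examples and item_pool are hypothetical names, not from the video:

    import random

    def build_pointwise_examples(positive_pairs, item_pool, neg_per_pos=2):
        # Label positives 1 and randomly drawn negatives 0 (a real pipeline
        # would also filter out accidental positives from the random draw).
        examples = []
        for query, pos_item in positive_pairs:
            examples.append((query, pos_item, 1.0))
            for neg_item in random.sample(item_pool, neg_per_pos):
                examples.append((query, neg_item, 0.0))
        return examples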


Reference:

https://www.youtube.com/watch?v=2Mc10LZ-DB0 (voiced in Chinese)
