diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java index 85abfbe4f..a028ac4b4 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java @@ -37,11 +37,15 @@ /** * Facilitates the generation of sentence vectors using * a sentence-transformers model converted to ONNX. + * + *

The model inputs follow the standard single-segment BERT + * encoding: {@code attention_mask} is {@code 1} for every real + * token and {@code token_type_ids} is {@code 0} throughout. */ public class SentenceVectorsDL extends AbstractDL { /** - * Instantiates a {@link SentenceVectorsDL sentence detector} using ONNX models. + * Instantiates a {@link SentenceVectorsDL sentence vector generator} using ONNX models. * * @param model The file name of a sentence vectors ONNX model. * @param vocabulary The file name of the vocabulary file for the model. @@ -54,7 +58,7 @@ public SentenceVectorsDL(final File model, final File vocabulary) env = OrtEnvironment.getEnvironment(); session = env.createSession(model.getPath(), new OrtSession.SessionOptions()); - vocab = loadVocab(new File(vocabulary.getPath())); + vocab = loadVocab(vocabulary); tokenizer = createTokenizer(vocab); } @@ -63,6 +67,7 @@ public SentenceVectorsDL(final File model, final File vocabulary) * Generates vectors given a sentence. * * @param sentence The input sentence. + * @return The sentence vector. * * @throws OrtException Thrown if an error occurs during inference. */ @@ -72,38 +77,61 @@ public float[] getVectors(final String sentence) throws OrtException { final Map inputs = new HashMap<>(); - inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids()), - new long[] {1, tokens.ids().length})); - - inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env, - LongBuffer.wrap(tokens.mask()), new long[] {1, tokens.mask().length})); + try { + inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids()), + new long[] {1, tokens.ids().length})); - inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env, - LongBuffer.wrap(tokens.types()), new long[] {1, tokens.types().length})); + inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env, + LongBuffer.wrap(tokens.mask()), new long[] {1, tokens.mask().length})); - final float[][][] v = (float[][][]) session.run(inputs).get(0).getValue(); + inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env, + LongBuffer.wrap(tokens.types()), new long[] {1, tokens.types().length})); - return v[0][0]; + try (OrtSession.Result result = session.run(inputs)) { + // getValue() copies the tensor into Java arrays, so the result can be closed safely. + final float[][][] v = (float[][][]) result.get(0).getValue(); + return v[0][0]; + } + } finally { + inputs.values().forEach(OnnxTensor::close); + } } - private Tokens tokenize(final String text, Tokenizer tokenizer, Map vocab) { + /** + * Encodes text as model inputs: wordpiece token ids, an attention mask of ones, + * and single-segment (all zero) token type ids. + * + * @param text The text to encode. + * @param tokenizer The wordpiece tokenizer matching the {@code vocab}. + * @param vocab The vocabulary map. + * @return The encoded {@link Tokens}. + * + * @throws IllegalArgumentException Thrown if the tokenizer emits a token that is + * not present in the vocabulary. + */ + static Tokens tokenize(final String text, final Tokenizer tokenizer, + final Map vocab) { final String[] tokens = tokenizer.tokenize(text); - final int[] ids = new int[tokens.length]; - final long[] mask = new long[ids.length]; + final long[] ids = new long[tokens.length]; for (int x = 0; x < tokens.length; x++) { - ids[x] = vocab.get(tokens[x]); + final Integer id = vocab.get(tokens[x]); + if (id == null) { + throw new IllegalArgumentException("Token '" + tokens[x] + + "' is not present in the vocabulary; the vocabulary file does not match the model."); + } + ids[x] = id; } - final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray(); + final long[] mask = new long[ids.length]; + Arrays.fill(mask, 1); final long[] types = new long[ids.length]; - Arrays.fill(types, 1); - return new Tokens(tokens, lids, mask, types); + return new Tokens(tokens, ids, mask, types); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java new file mode 100644 index 000000000..422c7773a --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.dl.vectors; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import opennlp.dl.Tokens; +import opennlp.tools.tokenize.WordpieceTokenizer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SentenceVectorsDLTest { + + private static Map vocab() { + final Map vocab = new HashMap<>(); + vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0); + vocab.put(WordpieceTokenizer.BERT_SEP_TOKEN, 1); + vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2); + vocab.put("hello", 3); + vocab.put("world", 4); + return vocab; + } + + @Test + void testTokenizeUsesSingleSegmentBertEncoding() { + final Map vocab = vocab(); + final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet()); + + final Tokens tokens = SentenceVectorsDL.tokenize("hello world", tokenizer, vocab); + + assertArrayEquals(new String[] { + WordpieceTokenizer.BERT_CLS_TOKEN, "hello", "world", WordpieceTokenizer.BERT_SEP_TOKEN}, + tokens.tokens()); + assertArrayEquals(new long[] {0, 3, 4, 1}, tokens.ids()); + // The attention mask must be 1 for every real token. + assertArrayEquals(new long[] {1, 1, 1, 1}, tokens.mask()); + // Single-segment input: all token type ids must be 0. + assertArrayEquals(new long[] {0, 0, 0, 0}, tokens.types()); + } + + @Test + void testTokenizeMapsOutOfVocabularyWordsToUnknownToken() { + final Map vocab = vocab(); + final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet()); + + final Tokens tokens = SentenceVectorsDL.tokenize("hello xyz", tokenizer, vocab); + + assertArrayEquals(new long[] {0, 3, 2, 1}, tokens.ids()); + assertEquals(WordpieceTokenizer.BERT_UNK_TOKEN, tokens.tokens()[2]); + } + + @Test + void testTokenizeRejectsTokensMissingFromVocabulary() { + final Map vocab = vocab(); + vocab.remove(WordpieceTokenizer.BERT_UNK_TOKEN); + final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet()); + + assertThrows(IllegalArgumentException.class, () -> + SentenceVectorsDL.tokenize("hello xyz", tokenizer, vocab)); + } +} diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java index 286a092db..9a42e204f 100644 --- a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java +++ b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java @@ -38,10 +38,10 @@ public void generateVectorsTest() throws Exception { final float[] vectors = sv.getVectors(sentence); - Assertions.assertEquals(vectors[0], 0.39994872, 0.00001); - Assertions.assertEquals(vectors[1], -0.055101186, 0.00001); - Assertions.assertEquals(vectors[2], 0.2817594, 0.00001); - Assertions.assertEquals(vectors.length, 384); + Assertions.assertEquals(0.044745024, vectors[0], 0.00001); + Assertions.assertEquals(0.20219636, vectors[1], 0.00001); + Assertions.assertEquals(0.41306049, vectors[2], 0.00001); + Assertions.assertEquals(384, vectors.length); } }