diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index 85abfbe4f..a028ac4b4 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -37,11 +37,15 @@
/**
* Facilitates the generation of sentence vectors using
* a sentence-transformers model converted to ONNX.
+ *
+ * <p>The model inputs follow the standard single-segment BERT
+ * encoding: {@code attention_mask} is {@code 1} for every real
+ * token and {@code token_type_ids} is {@code 0} throughout.
*/
public class SentenceVectorsDL extends AbstractDL {
/**
- * Instantiates a {@link SentenceVectorsDL sentence detector} using ONNX models.
+ * Instantiates a {@link SentenceVectorsDL sentence vector generator} using ONNX models.
*
* @param model The file name of a sentence vectors ONNX model.
* @param vocabulary The file name of the vocabulary file for the model.
@@ -54,7 +58,7 @@ public SentenceVectorsDL(final File model, final File vocabulary)
env = OrtEnvironment.getEnvironment();
session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
- vocab = loadVocab(new File(vocabulary.getPath()));
+ vocab = loadVocab(vocabulary);
tokenizer = createTokenizer(vocab);
}
@@ -63,6 +67,7 @@ public SentenceVectorsDL(final File model, final File vocabulary)
* Generates vectors given a sentence.
*
* @param sentence The input sentence.
+ * @return The sentence vector.
*
* @throws OrtException Thrown if an error occurs during inference.
*/
@@ -72,38 +77,61 @@ public float[] getVectors(final String sentence) throws OrtException {
final Map<String, OnnxTensor> inputs = new HashMap<>();
- inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids()),
- new long[] {1, tokens.ids().length}));
-
- inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
- LongBuffer.wrap(tokens.mask()), new long[] {1, tokens.mask().length}));
+ try {
+ inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids()),
+ new long[] {1, tokens.ids().length}));
- inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
- LongBuffer.wrap(tokens.types()), new long[] {1, tokens.types().length}));
+ inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.mask()), new long[] {1, tokens.mask().length}));
- final float[][][] v = (float[][][]) session.run(inputs).get(0).getValue();
+ inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.types()), new long[] {1, tokens.types().length}));
- return v[0][0];
+ try (OrtSession.Result result = session.run(inputs)) {
+ // getValue() copies the tensor into Java arrays, so the result can be closed safely.
+ final float[][][] v = (float[][][]) result.get(0).getValue();
+ return v[0][0];
+ }
+ } finally {
+ inputs.values().forEach(OnnxTensor::close);
+ }
}
- private Tokens tokenize(final String text, Tokenizer tokenizer, Map<String, Integer> vocab) {
+ /**
+ * Encodes text as model inputs: wordpiece token ids, an attention mask of ones,
+ * and single-segment (all zero) token type ids.
+ *
+ * @param text The text to encode.
+ * @param tokenizer The wordpiece tokenizer matching the {@code vocab}.
+ * @param vocab The vocabulary map.
+ * @return The encoded {@link Tokens}.
+ *
+ * @throws IllegalArgumentException Thrown if the tokenizer emits a token that is
+ * not present in the vocabulary.
+ */
+ static Tokens tokenize(final String text, final Tokenizer tokenizer,
+ final Map<String, Integer> vocab) {
final String[] tokens = tokenizer.tokenize(text);
- final int[] ids = new int[tokens.length];
- final long[] mask = new long[ids.length];
+ final long[] ids = new long[tokens.length];
for (int x = 0; x < tokens.length; x++) {
- ids[x] = vocab.get(tokens[x]);
+ final Integer id = vocab.get(tokens[x]);
+ if (id == null) {
+ throw new IllegalArgumentException("Token '" + tokens[x]
+ + "' is not present in the vocabulary; the vocabulary file does not match the model.");
+ }
+ ids[x] = id;
}
- final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
+ final long[] mask = new long[ids.length];
+ Arrays.fill(mask, 1);
final long[] types = new long[ids.length];
- Arrays.fill(types, 1);
- return new Tokens(tokens, lids, mask, types);
+ return new Tokens(tokens, ids, mask, types);
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java
new file mode 100644
index 000000000..422c7773a
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLTest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.vectors;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.dl.Tokens;
+import opennlp.tools.tokenize.WordpieceTokenizer;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class SentenceVectorsDLTest {
+
+ private static Map<String, Integer> vocab() {
+ final Map<String, Integer> vocab = new HashMap<>();
+ vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0);
+ vocab.put(WordpieceTokenizer.BERT_SEP_TOKEN, 1);
+ vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
+ vocab.put("hello", 3);
+ vocab.put("world", 4);
+ return vocab;
+ }
+
+ @Test
+ void testTokenizeUsesSingleSegmentBertEncoding() {
+ final Map<String, Integer> vocab = vocab();
+ final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet());
+
+ final Tokens tokens = SentenceVectorsDL.tokenize("hello world", tokenizer, vocab);
+
+ assertArrayEquals(new String[] {
+ WordpieceTokenizer.BERT_CLS_TOKEN, "hello", "world", WordpieceTokenizer.BERT_SEP_TOKEN},
+ tokens.tokens());
+ assertArrayEquals(new long[] {0, 3, 4, 1}, tokens.ids());
+ // The attention mask must be 1 for every real token.
+ assertArrayEquals(new long[] {1, 1, 1, 1}, tokens.mask());
+ // Single-segment input: all token type ids must be 0.
+ assertArrayEquals(new long[] {0, 0, 0, 0}, tokens.types());
+ }
+
+ @Test
+ void testTokenizeMapsOutOfVocabularyWordsToUnknownToken() {
+ final Map<String, Integer> vocab = vocab();
+ final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet());
+
+ final Tokens tokens = SentenceVectorsDL.tokenize("hello xyz", tokenizer, vocab);
+
+ assertArrayEquals(new long[] {0, 3, 2, 1}, tokens.ids());
+ assertEquals(WordpieceTokenizer.BERT_UNK_TOKEN, tokens.tokens()[2]);
+ }
+
+ @Test
+ void testTokenizeRejectsTokensMissingFromVocabulary() {
+ final Map<String, Integer> vocab = vocab();
+ vocab.remove(WordpieceTokenizer.BERT_UNK_TOKEN);
+ final WordpieceTokenizer tokenizer = new WordpieceTokenizer(vocab.keySet());
+
+ assertThrows(IllegalArgumentException.class, () ->
+ SentenceVectorsDL.tokenize("hello xyz", tokenizer, vocab));
+ }
+}
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
index 286a092db..9a42e204f 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
@@ -38,10 +38,10 @@ public void generateVectorsTest() throws Exception {
final float[] vectors = sv.getVectors(sentence);
- Assertions.assertEquals(vectors[0], 0.39994872, 0.00001);
- Assertions.assertEquals(vectors[1], -0.055101186, 0.00001);
- Assertions.assertEquals(vectors[2], 0.2817594, 0.00001);
- Assertions.assertEquals(vectors.length, 384);
+ Assertions.assertEquals(0.044745024, vectors[0], 0.00001);
+ Assertions.assertEquals(0.20219636, vectors[1], 0.00001);
+ Assertions.assertEquals(0.41306049, vectors[2], 0.00001);
+ Assertions.assertEquals(384, vectors.length);
}
}