diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7979518 --- /dev/null +++ b/.gitignore @@ -0,0 +1,57 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +*.manifest +*.spec + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db diff --git a/ruaccent/accent_model.py b/ruaccent/accent_model.py index ee4b255..4b23c8f 100644 --- a/ruaccent/accent_model.py +++ b/ruaccent/accent_model.py @@ -33,6 +33,11 @@ def put_accent(self, word): lower_word = word.lower() inputs = self.tokenizer(lower_word, return_tensors="np") inputs = {k: v.astype(np.int64) for k, v in inputs.items()} + + # Add token_type_ids if missing (zeros with same shape as input_ids) + if 'token_type_ids' not in inputs: + inputs['token_type_ids'] = np.zeros_like(inputs['input_ids']) + outputs = self.session.run(None, inputs) output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} logits = outputs[output_names["logits"]]