Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hezar/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class PrecisionType(ExplicitEnum):
class OptimizerType(ExplicitEnum):
ADAM = "adam"
ADAMW = "adamw"
SDG = "sdg"
SGD = "sgd"


class LRSchedulerType(ExplicitEnum):
Expand All @@ -143,7 +143,7 @@ class LRSchedulerType(ExplicitEnum):
CYCLIC = "cyclic"
SEQUENTIAL = "sequential"
POLYNOMIAL = "polynomial"
COSINE_ANEALING = "cosine_anealing"
COSINE_ANNEALING = "cosine_annealing"


class SplitType(ExplicitEnum):
Expand Down
5 changes: 4 additions & 1 deletion hezar/data/data_collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __call__(self, input_batch):
Returns:
Dict: The same batch dictionary but padded
"""
self.tokenizer.config.padding_side = self.padding_side
input_batch = self.tokenizer.pad_encoded_batch(
input_batch,
padding=self.padding,
Expand Down Expand Up @@ -131,6 +132,7 @@ def __call__(self, input_batch):
Returns:
Dict: The same batch dictionary but padded
"""
self.tokenizer.config.padding_side = self.padding_side
input_batch = [convert_batch_dict_dtype(x, dtype="list") for x in input_batch]
input_batch = _convert_to_batch_dict(input_batch)
padded_batch = self.tokenizer.pad_encoded_batch(
Expand Down Expand Up @@ -176,6 +178,7 @@ def __init__(
self.max_length = max_length

def __call__(self, input_batch):
self.tokenizer.config.padding_side = self.padding_side
input_batch = _convert_to_batch_dict(input_batch)
input_batch = self.tokenizer.pad_encoded_batch(
input_batch,
Expand Down Expand Up @@ -329,7 +332,7 @@ def __call__(self, input_batch):
max_length = max(map(len, input_batch["labels"]))
all_labels = []
for labels in input_batch["labels"]:
labels += [self.pad_token_id] * (max_length - len(labels))
labels = labels + [self.pad_token_id] * (max_length - len(labels))
all_labels.append(labels)
input_batch["labels"] = torch.tensor(all_labels)
return input_batch
2 changes: 1 addition & 1 deletion hezar/data/data_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _create_indices(self):
return indices

def __len__(self):
return self.total_length
return len(self.indices)

def __iter__(self):
for indice in self.indices:
Expand Down
2 changes: 1 addition & 1 deletion hezar/data/dataset_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _text_to_ids(self, text):

"""
if self.text_split_type == "tokenize" and self.tokenizer:
token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["input_ids"]
token_ids = self.tokenizer(text, padding="max_length", max_length=self.max_length)["token_ids"]
labels = [token_id if token_id != self.tokenizer.pad_token_id else -100 for token_id in token_ids]
elif self.text_split_type == "char_split":
if self.reverse_digits:
Expand Down
11 changes: 6 additions & 5 deletions hezar/data/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
**kwargs,
):
verify_dependencies(self, self.required_backends)
self.cache_dir = kwargs.pop("cache_dir", None) or self.cache_dir
self.config = config.update(kwargs)
self.split = split
self.data = self._load(self.split)
Expand Down Expand Up @@ -160,13 +161,12 @@ def load(
split = split or "train"
config_filename = config_filename or cls.config_filename

if ":" in hub_path:
hub_path, hf_dataset_config_name = hub_path.split(":")
if ":" in hub_path and not os.path.exists(hub_path):
hub_path, hf_dataset_config_name = hub_path.split(":", 1)
kwargs["hf_load_kwargs"] = kwargs.get("hf_load_kwargs", {})
kwargs["hf_load_kwargs"]["name"] = hf_dataset_config_name

if cache_dir is not None:
cls.cache_dir = cache_dir
cache_dir = cache_dir or cls.cache_dir

has_config = config_filename in list_repo_files(hub_path, repo_type="dataset")

Expand All @@ -177,7 +177,7 @@ def load(
hub_path,
filename=config_filename,
repo_type=RepoType.DATASET,
cache_dir=cls.cache_dir,
cache_dir=cache_dir,
**kwargs,
)
elif kwargs.get("task", None):
Expand All @@ -199,6 +199,7 @@ def load(
config=dataset_config,
split=split,
preprocessor=preprocessor,
cache_dir=cache_dir,
**kwargs,
)
return dataset
4 changes: 2 additions & 2 deletions hezar/data/datasets/ocr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class OCRDatasetConfig(DatasetConfig):
images_paths_column: str = "image_path"
max_length: int | None = None
invalid_characters: list | None = None
reverse_text: bool | None = None
reverse_digits: bool | None = None


Expand Down Expand Up @@ -110,7 +109,8 @@ def _load(self, split=None):
invalid_indices = []
for i, sample in enumerate(list(iter(data))):
path, text = sample.values()
if len(text) <= self.config.max_length and is_text_valid(text, self.config.id2label.values()):
within_len = self.config.max_length is None or len(text) <= self.config.max_length
if within_len and is_text_valid(text, self.config.id2label.values()):
valid_indices.append(i)
else:
invalid_indices.append(i)
Expand Down
2 changes: 1 addition & 1 deletion hezar/data/datasets/text_summarization_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __getitem__(self, index):
)
labels = self.tokenizer(
summary,
max_length=self.config.max_length,
max_length=self.config.labels_max_length,
padding="max_length" if self.config.labels_max_length else None,
return_attention_mask=True,
)
Expand Down
11 changes: 7 additions & 4 deletions hezar/embeddings/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import os
import tempfile
from typing import Self
Expand Down Expand Up @@ -43,8 +44,6 @@ def __init__(
self, config: EmbeddingConfig, embedding_file: str | None = None, vectors_file: str | None = None, **kwargs
):
verify_dependencies(self, self.required_backends) # Check if all the required dependencies are installed
self.config = config.update(kwargs)

self.config = config.update(kwargs)
self.model = self.from_file(embedding_file, vectors_file) if embedding_file else self.build()

Expand Down Expand Up @@ -102,6 +101,11 @@ def word2index(self, word):
"""
return self.vocab.get(word, -1)

@functools.cached_property
def _index2word(self):
"""Reverse vocabulary mapping (index -> word), built once and cached."""
return {v: k for k, v in self.vocab.items()}

def index2word(self, index):
"""
Get the word corresponding to a given index.
Expand All @@ -112,8 +116,7 @@ def index2word(self, index):
Returns:
str: Word corresponding to the index.
"""
keyed_vocab = {v: k for k, v in self.vocab.items()}
return keyed_vocab[index]
return self._index2word[index]

def similarity(self, word1: str, word2: str):
"""
Expand Down
4 changes: 2 additions & 2 deletions hezar/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def compute(
Returns:
A dictionary of the metric results
"""
normalize = normalize or self.config.normalize
normalize = normalize if normalize is not None else self.config.normalize
sample_weight = sample_weight or self.config.sample_weight
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

score = accuracy_score(
Expand Down
5 changes: 3 additions & 2 deletions hezar/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ def compute(
Returns:
dict: A dictionary of the metric results, with keys specified by `output_keys`.
"""
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

predictions = [x.split() if isinstance(x, str) else x for x in predictions] # ty:ignore
targets = [x.split() if isinstance(x, str) else x for x in targets] # ty:ignore
# Each hypothesis needs a *list of references*; wrap a tokenized string reference accordingly.
targets = [[x.split()] if isinstance(x, str) else x for x in targets] # ty:ignore

score = corpus_bleu(targets, predictions, weights=weights)

Expand Down
2 changes: 1 addition & 1 deletion hezar/metrics/cer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compute(
dict: A dictionary of the metric results, with keys specified by `output_keys`.
"""
concatenate_texts = concatenate_texts or self.config.concatenate_texts
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals

if concatenate_texts:
score = jiwer.process_words(
Expand Down
2 changes: 1 addition & 1 deletion hezar/metrics/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def compute(
pos_label = pos_label or self.config.pos_label
average = average or self.config.average
sample_weight = sample_weight or self.config.sample_weight
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

score = f1_score(
Expand Down
2 changes: 1 addition & 1 deletion hezar/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def compute(
average = average or self.config.average
sample_weight = sample_weight or self.config.sample_weight
zero_division = zero_division or self.config.zero_division
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

score = precision_score(
Expand Down
2 changes: 1 addition & 1 deletion hezar/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def compute(
average = average or self.config.average
sample_weight = sample_weight or self.config.sample_weight
zero_division = zero_division or self.config.zero_division
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

score = recall_score(
Expand Down
39 changes: 23 additions & 16 deletions hezar/metrics/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,28 @@ def compute(
Returns:
dict: A dictionary of the metric results, with keys specified by `output_keys`.
"""
aggregator = scoring.BootstrapAggregator()

for ref, pred in zip(targets, predictions, strict=True):
if self.config.multi_ref:
score = self.scorer.score_multi(ref, pred)
else:
score = self.scorer.score(ref, pred)

aggregator.add_scores(score)

results = aggregator.aggregate()
for key in results:
results[key] = results[key].mid.fmeasure

if output_keys:
results = {k: round(v, 4) for k, v in results.items() if k in output_keys}
use_aggregator = use_aggregator if use_aggregator is not None else self.config.use_aggregator
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

score_fn = self.scorer.score_multi if self.config.multi_ref else self.scorer.score

if use_aggregator:
aggregator = scoring.BootstrapAggregator()
for ref, pred in zip(targets, predictions, strict=True):
aggregator.add_scores(score_fn(ref, pred))
agg = aggregator.aggregate()
results = {k: agg[k].mid.fmeasure for k in agg}
else:
sums = {}
n = 0
for ref, pred in zip(targets, predictions, strict=True):
scores = score_fn(ref, pred)
for k, v in scores.items():
sums[k] = sums.get(k, 0.0) + v.fmeasure
n += 1
results = {k: (v / n if n else 0.0) for k, v in sums.items()}

results = {k: round(v, n_decimals) for k, v in results.items() if (not output_keys or k in output_keys)}

return results
4 changes: 2 additions & 2 deletions hezar/metrics/seqeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def compute(
mode = mode or self.config.mode
sample_weight = sample_weight or self.config.sample_weight
zero_division = zero_division or self.config.zero_division
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals
output_keys = output_keys or self.config.output_keys

report = classification_report(
Expand All @@ -106,7 +106,7 @@ def compute(
overall_score = report.pop("micro avg")

results = {
"accuracy": format(accuracy_score(predictions, targets)),
"accuracy": accuracy_score(targets, predictions),
"f1": overall_score["f1-score"],
"recall": overall_score["recall"],
"precision": overall_score["precision"],
Expand Down
2 changes: 1 addition & 1 deletion hezar/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def compute(
dict: A dictionary of the metric results, with keys specified by `output_keys`.
"""
concatenate_texts = concatenate_texts or self.config.concatenate_texts
n_decimals = n_decimals or self.config.n_decimals
n_decimals = n_decimals if n_decimals is not None else self.config.n_decimals

if concatenate_texts:
score = jiwer.process_words(targets, predictions).wer
Expand Down
6 changes: 4 additions & 2 deletions hezar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
losses_mapping = {
LossType.L1: nn.L1Loss,
LossType.NLL: nn.NLLLoss,
LossType.NLL_2D: nn.NLLLoss2d, # ty:ignore
LossType.NLL_2D: nn.NLLLoss,
LossType.POISSON_NLL: nn.PoissonNLLLoss,
LossType.GAUSSIAN_NLL: nn.GaussianNLLLoss,
LossType.MSE: nn.MSELoss,
Expand Down Expand Up @@ -128,7 +128,7 @@ def load(
The fully loaded Hezar model
"""
# Get device if provided in the kwargs
device = None or kwargs.pop("device", None)
device = kwargs.pop("device", None)
# Load config
config_filename = config_filename or cls.config_filename
cache_dir = cache_dir or HEZAR_CACHE_DIR
Expand Down Expand Up @@ -277,6 +277,8 @@ def save(
if self.preprocessor is not None:
self.preprocessor.save(path)

return model_save_path

def push_to_hub(
self,
repo_id: str,
Expand Down
4 changes: 2 additions & 2 deletions hezar/models/model_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def items(self):

@dataclass(repr=False)
class MaskFillingOutput(ModelOutput):
token: Optional[int] = None
token: str | None = None
sequence: str | None = None
token_id: str | None = None
token_id: int | None = None
score: Optional[float] = None


Expand Down
Loading