pybergamot/pybergamot/models.py
2023-11-03 01:33:56 +01:00

141 lines
4.7 KiB
Python

"""
pybergamot - (Somewhat) stable interface for the **Bergamot Translation Engine Python Bindings**.
Copyright (C) 2023 Ad5001
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Lists the models provided by all repositories and resolves the languages each model translates between.
"""
from bergamot import REPOSITORY
from languagecodes import iso_639_alpha2
class Models:
REPO_FOR_MODEL = {}
AVAILABLE = []
INSTALLED = []
LANGS = []
@staticmethod
def update_models_list() -> None:
"""
Imports the list of models from all repositories.
"""
Models.REPO_FOR_MODEL = {model_name: repo
for repo in REPOSITORY.repositories
for model_name in REPOSITORY.repositories[repo].models(False)}
Models.AVAILABLE = [model_name
for repo in REPOSITORY.repositories
for model_name in REPOSITORY.repositories[repo].models(False)]
Models.INSTALLED = [model_name
for repo in REPOSITORY.repositories
for model_name in REPOSITORY.repositories[repo].models(True)]
Models.LANGS = []
for model_name in Models.AVAILABLE:
lang1, lang2 = Models.get_model_languages(model_name)
if lang1 not in Models.LANGS:
Models.LANGS.append(lang1)
if lang2 not in Models.LANGS:
Models.LANGS.append(lang2)
@staticmethod
def update_repositories_cache() -> None:
"""
Fetches the online models list for every repository.
"""
for repo in REPOSITORY.repositories:
REPOSITORY.repositories[repo].update()
Models.update_models_list()
@staticmethod
def get_model_languages(model_name: str) -> tuple:
"""
Returns a tuple of two two-char ISO language name which the model translates from and to.
## Parameters
---
- **model_name**: Name of the model
## Exceptions
---
- `ValueError`: When the model_name doesn't exist.
## Returns
---
(from language, to language)
"""
if model_name not in Models.AVAILABLE:
raise ValueError(f"Model {model_name} does not exist. Did you update the repository cache?")
model = REPOSITORY.model(Models.REPO_FOR_MODEL[model_name], model_name)
src, target, tiny = model['code'].split("-")
if len(src) == 3:
src = iso_639_alpha2(src)
if len(target) == 3:
target = iso_639_alpha2(target)
return src, target
@staticmethod
def get_model_name_for_languages(source_lang: str, target_lang: str) -> str | None:
"""
Finds a model which translates source_lang into target_lang.
## Parameters
---
- **source_lang**: Language to translate from.
- **target_lang**: Language to translate to.
## Returns
---
None if no model was found, name of the model otherwise.
"""
lang_tuple = (source_lang, target_lang)
names = list(filter(lambda name: lang_tuple == Models.get_model_languages(name), Models.AVAILABLE))
if len(names) > 0:
model_name = names[0]
else:
model_name = None
return model_name
@staticmethod
def download(model_name: str) -> None:
"""
Downloads or updates the given model.
## Parameters
---
- **model_name**: Name of the model to download.
## Exceptions
---
- `ValueError`: When the model_name doesn't exist.
"""
if model_name not in Models.AVAILABLE:
raise ValueError(f"Model {model_name} does not exist. Did you update the repository cache?")
REPOSITORY.download(Models.REPO_FOR_MODEL[model_name], model_name)
@staticmethod
def update_all_models() -> None:
"""
Updates all already downloaded models to their latest versions.
"""
for model_name in Models.INSTALLED:
REPOSITORY.download(Models.REPO_FOR_MODEL[model_name], model_name)
Models.update_models_list()