120 lines
4.5 KiB
Python
120 lines
4.5 KiB
Python
|
"""
|
||
|
pybergamot - (Somewhat) stable interface for the **Bergamot Translation Engine Python Bindings**.
|
||
|
Copyright (C) 2023 Ad5001
|
||
|
|
||
|
This program is free software: you can redistribute it and/or modify
|
||
|
it under the terms of the GNU General Public License as published by
|
||
|
the Free Software Foundation, either version 3 of the License, or
|
||
|
(at your option) any later version.
|
||
|
|
||
|
This program is distributed in the hope that it will be useful,
|
||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
|
GNU General Public License for more details.
|
||
|
|
||
|
You should have received a copy of the GNU General Public License
|
||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
|
|
||
|
|
||
|
Lists the models offered by every repository and maps each model to its source and target languages.
|
||
|
"""
|
||
|
|
||
|
from bergamot import REPOSITORY
|
||
|
from languagecodes import iso_639_alpha2
|
||
|
|
||
|
|
||
|
class Models:
    """
    Registry of the translation models offered by every repository in ``REPOSITORY``.

    All state lives in class attributes and every method is a ``staticmethod``, so the
    class is used as a namespace rather than instantiated. Call update_models_list()
    (or update_repositories_cache(), which also calls ``update()`` on each repository
    first) to (re)populate it.
    """

    # Maps each model name to the name of the repository that provides it. If several
    # repositories expose the same model name, the last repository iterated wins.
    REPO_FOR_MODEL: dict[str, str] = {}
    # Every model name returned by ``models(False)`` of any repository, in repository order.
    AVAILABLE: list[str] = []
    # Model names returned by ``models(True)`` — presumably the already-downloaded ones.
    INSTALLED: list[str] = []
    # Every language code that at least one available model translates from or to,
    # deduplicated, in first-seen order.
    LANGS: list[str] = []

    @staticmethod
    def _repo_for(model_name: str) -> str:
        """
        Returns the name of the repository providing model_name.

        :raises:
            ValueError: When the model_name doesn't exist.
        """
        try:
            return Models.REPO_FOR_MODEL[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} does not exist. Did you update the repository cache?") from None

    @staticmethod
    def _to_alpha2(lang: str) -> str:
        """
        Converts a three-letter language code to its two-letter ISO 639-1 equivalent.

        Codes of any other length are returned unchanged. When ``iso_639_alpha2`` has no
        two-letter equivalent (it returns a falsy value), the original code is kept.
        Without that fallback, distinct languages would all collapse into ``None``.
        """
        if len(lang) == 3:
            return iso_639_alpha2(lang) or lang
        return lang

    @staticmethod
    def update_models_list() -> None:
        """
        Imports the list of models from all repositories.

        Rebuilds REPO_FOR_MODEL, AVAILABLE, INSTALLED and LANGS from the repositories
        currently registered in ``REPOSITORY``. Does not call ``update()`` on the
        repositories; use update_repositories_cache() to refresh them first.

        :raises:
            ValueError: When a model's code cannot be parsed into two languages.
        """
        repo_for_model = {}
        available = []
        installed = []
        for repo in REPOSITORY.repositories:
            repository = REPOSITORY.repositories[repo]
            # Materialise once so the list can be iterated twice below.
            all_models = list(repository.models(False))
            for model_name in all_models:
                repo_for_model[model_name] = repo
            available.extend(all_models)
            installed.extend(repository.models(True))

        Models.REPO_FOR_MODEL = repo_for_model
        Models.AVAILABLE = available
        Models.INSTALLED = installed
        # get_model_languages() looks models up in the attributes assigned above, so
        # LANGS can only be computed afterwards. Building it before assigning means a
        # failure leaves the previous LANGS intact instead of a partial list.
        # dict.fromkeys drops duplicates while keeping first-seen order.
        Models.LANGS = list(dict.fromkeys(
            lang
            for model_name in available
            for lang in Models.get_model_languages(model_name)
        ))

    @staticmethod
    def update_repositories_cache() -> None:
        """
        Fetches the online models list for every repository.

        Calls ``update()`` on each repository in ``REPOSITORY``, then rebuilds the
        model lists with update_models_list().
        """
        for repo in REPOSITORY.repositories:
            REPOSITORY.repositories[repo].update()
        Models.update_models_list()

    @staticmethod
    def get_model_languages(model_name: str) -> tuple[str, str]:
        """
        Returns a tuple of the two language codes which the model translates from and to.

        The languages are the first two ``-``-separated parts of the model's ``code``
        field. Any further parts (such as a size suffix) are ignored. Three-letter
        codes are converted to two-letter ISO 639-1 codes when one exists.

        :param model_name: Name of the model
        :raises:
            ValueError: When the model_name doesn't exist, or when its code has fewer
            than two ``-``-separated parts.
        :return: (from language, to language)
        """
        repo = Models._repo_for(model_name)
        model = REPOSITORY.model(repo, model_name)
        # Star-unpacking tolerates codes with no suffix or with several suffix parts.
        # Exactly three parts were required before.
        src, target, *_ = model['code'].split("-")
        return Models._to_alpha2(src), Models._to_alpha2(target)

    @staticmethod
    def get_model_name_for_languages(source_lang: str, target_lang: str) -> str | None:
        """
        Finds a model which translates source_lang into target_lang.

        Models are checked in AVAILABLE order and the first match is returned; the
        remaining models are not inspected.

        :param source_lang: Language to translate from.
        :param target_lang: Language to translate to.
        :return: None if no model was found, name of the model otherwise.
        """
        wanted = (source_lang, target_lang)
        return next(
            (name for name in Models.AVAILABLE if Models.get_model_languages(name) == wanted),
            None,
        )

    @staticmethod
    def download(model_name: str) -> None:
        """
        Downloads or updates the given model.

        On success, the model is also added to INSTALLED (if missing), so that
        update_all_models() covers it without waiting for update_models_list().

        :param model_name: Name of the model to download.
        :raises:
            ValueError: When the model_name doesn't exist.
        """
        REPOSITORY.download(Models._repo_for(model_name), model_name)
        # NOTE(review): assumes a successful REPOSITORY.download() makes the model appear
        # in models(True) as well — confirm against the repository implementation.
        if model_name not in Models.INSTALLED:
            Models.INSTALLED.append(model_name)

    @staticmethod
    def update_all_models() -> None:
        """
        Updates all already downloaded models to their latest versions.

        Re-downloads every model listed in INSTALLED from the repository recorded in
        REPO_FOR_MODEL.
        """
        for model_name in Models.INSTALLED:
            REPOSITORY.download(Models.REPO_FOR_MODEL[model_name], model_name)
|
||
|
|
||
|
|
||
|
# Populate Models from the repositories' current data as soon as this module is imported.
# This does not call update() on the repositories; call Models.update_repositories_cache()
# to refresh them.
Models.update_models_list()
|