pybergamot/pybergamot/translator.py
2023-11-03 00:36:46 +01:00

155 lines
7.4 KiB
Python

"""
pybergamot - (Somewhat) stable interface for the **Bergamot Translation Engine Python Bindings**.
Copyright (C) 2023 Ad5001
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from bergamot import REPOSITORY, TranslationModel, Service, ServiceConfig
from warnings import warn
from .models import Models
from .engine import Engine, DirectBergamotModelEngine, ChainBergamotModelsEngine
class Translator:
    """
    Main exposed class to provide translation using Bergamot.

    Workflow goes as follows:
    1. Create an instance.
    2. Load languages with :meth:`load`.
    3. Translate between any pair of loaded languages with :meth:`translate`.
    """

    def __init__(self, workers_count: int = 1, cache_size: int = 0, log_level: str = 'off'):
        """
        Creates a Translator instance.

        :param workers_count: Number of workers which can be used at once.
        :param cache_size: Size of the cache used in bergamot.
        :param log_level: Level of logs used in bergamot.
        """
        # Languages successfully registered through load(), in load order.
        self.loaded_languages = []
        # Nested mapping: source language -> target language -> Engine.
        self._loaded_engines = {}
        config = ServiceConfig(numWorkers=workers_count, cacheSize=cache_size, logLevel=log_level)
        self.service = Service(config)

    def _load_model(self, model_name: str, download: bool = True) -> TranslationModel:
        """
        Loads a tiny model by its name, downloads it if it doesn't exist.

        :param model_name: Name of the model to load.
        :param download: If a model does not exist locally, if True, download it,
            otherwise emit an error.
        :raises:
            ValueError: If the provided model does not exist.
            EnvironmentError: When a model is unavailable and download has been set to false.
        :return: Bergamot translation model instance.
        """
        if model_name not in Models.AVAILABLE:
            raise ValueError(f"Model {model_name} not available.")
        # Check if the model needs to be downloaded.
        if model_name not in Models.INSTALLED:
            if download:
                Models.download(model_name)
            else:
                langs = Models.get_model_languages(model_name)
                raise EnvironmentError(f"Translation model from {langs[0]} to {langs[1]} is not installed locally.")
        # Resolve the model's config file within the repository it belongs to,
        # then let the service build the model from it.
        model_path = REPOSITORY.modelConfigPath(Models.REPO_FOR_MODEL[model_name], model_name)
        return self.service.modelFromConfigPath(model_path)

    def _direct_engine(self, source_lang: str, target_lang: str, model_name: str, download: bool) -> Engine:
        """
        Builds an engine translating with the single model ``model_name``.

        :raises:
            ValueError: If the model does not exist.
            EnvironmentError: When the model is not installed and download has been set to false.
        """
        return DirectBergamotModelEngine(
            source_lang, target_lang, self._load_model(model_name, download), self.service
        )

    def _create_engine(self, source_lang: str, target_lang: str, download: bool = True) -> Engine:
        """
        Creates an Engine to translate a source lang to a target lang.

        A direct model is preferred. If there is none, or it is not installed while
        ``download`` is False, the translation is chained through English when possible.

        :param source_lang: Language to translate from.
        :param target_lang: Language to translate to.
        :param download: If a model does not exist locally, if True, download it,
            otherwise emit an error.
        :raises:
            ValueError: If neither a direct model nor both English pivot models exist.
            EnvironmentError: When a required model is unavailable and download has been set to false.
        :return: Engine instance.
        """
        direct_model_name = Models.get_model_name_for_languages(source_lang, target_lang)
        if direct_model_name is not None and (download or direct_model_name in Models.INSTALLED):
            # Direct model exists, and is installed locally if option download is disabled.
            return self._direct_engine(source_lang, target_lang, direct_model_name, download)

        # Try to chain models with English as intermediary. Pivoting is meaningless when
        # one side already is English: it would require an English-to-English model.
        can_pivot = "en" not in (source_lang, target_lang)
        model1 = Models.get_model_name_for_languages(source_lang, "en") if can_pivot else None
        model2 = Models.get_model_name_for_languages("en", target_lang) if can_pivot else None
        if model1 is not None and model2 is not None:
            return ChainBergamotModelsEngine(
                source_lang, target_lang,
                self._load_model(model1, download), self._load_model(model2, download),
                self.service
            )

        if direct_model_name is not None:
            # No pivot is possible, but a direct model exists that simply is not installed:
            # loading it raises the accurate EnvironmentError instead of a misleading ValueError.
            return self._direct_engine(source_lang, target_lang, direct_model_name, download)

        if not can_pivot:
            raise ValueError(f"Missing translation models between {source_lang} and {target_lang}.")
        if model1 is None:
            raise ValueError(f"Missing translation models between English and {source_lang}.")
        raise ValueError(f"Missing translation models between English and {target_lang}.")

    def load(self, lang: str, download: bool = True) -> None:
        """
        Loads a language code and all the associated models (for already added languages)
        into the translator.

        The operation is atomic: if any engine cannot be created, the translator is
        left exactly as it was before the call.

        :param lang: Two-letter ISO language code.
        :param download: If a model does not exist locally, if True, download it,
            otherwise emit an error.
        :raises:
            ValueError: If the language does not exist, or if no model chain links it
                with an already loaded language.
            EnvironmentError: When a model is unavailable and download has been set to false.
        """
        if lang not in Models.LANGS:
            raise ValueError(f"Language {lang} does not exist.")
        if lang in self.loaded_languages:
            warn(f"Language {lang} has already been imported.", RuntimeWarning)
            return
        # Build every engine before touching the translator's state, so that a failure
        # half-way through does not leave dangling, partially registered engines.
        forward_engines = {}
        backward_engines = {}
        for other_lang in self.loaded_languages:
            forward_engines[other_lang] = self._create_engine(lang, other_lang, download)
            backward_engines[other_lang] = self._create_engine(other_lang, lang, download)
        # Commit: register the language and its engines in both directions.
        self._loaded_engines[lang] = forward_engines
        for other_lang, engine in backward_engines.items():
            self._loaded_engines[other_lang][lang] = engine
        self.loaded_languages.append(lang)

    def translate(self, source_lang: str, target_lang: str, text: str,
                  html: bool = False, alignment: bool = False, quality_scores: bool = False) -> str:
        """
        Translates a text from a source lang to a target lang.

        If ``source_lang`` and ``target_lang`` are the same, ``text`` is returned
        unchanged (``html``, ``alignment`` and ``quality_scores`` are then ignored).

        :param source_lang: Language to translate from.
        :param target_lang: Language to translate to.
        :param text: Text to translate.
        :param html: Set to True if the text contains an HTML structure which needs to
            be preserved while translated.
        :param alignment: Toggle for alignment.
        :param quality_scores: Toggle for whether to include the translation's quality scores
            for each word in HTML format.
        :raises:
            ValueError: Either source_lang or target_lang haven't been loaded yet.
        :return: The translated text.
        """
        if source_lang not in self.loaded_languages:
            raise ValueError(f"Language {source_lang} is not loaded. Use the load() function first.")
        if target_lang not in self.loaded_languages:
            raise ValueError(f"Language {target_lang} is not loaded. Use the load() function first.")
        if source_lang == target_lang:
            # load() never builds an engine from a language to itself, so looking one up
            # would raise KeyError; there is nothing to translate anyway.
            return text
        return self._loaded_engines[source_lang][target_lang].translate(text, html, alignment, quality_scores)