155 lines
7.4 KiB
Python
155 lines
7.4 KiB
Python
"""
pybergamot - (Somewhat) stable interface for the **Bergamot Translation Engine Python Bindings**.
Copyright (C) 2023 Ad5001

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
|
|
|
|
from bergamot import REPOSITORY, TranslationModel, Service, ServiceConfig
|
|
from warnings import warn
|
|
|
|
from .models import Models
|
|
from .engine import Engine, DirectBergamotModelEngine, ChainBergamotModelsEngine
|
|
|
|
|
|
class Translator:
    """
    Main exposed class to provide translation using Bergamot.

    Workflow goes as follows:

    1. Create an instance.
    2. Load languages with :meth:`load`.
    3. Translate between any two of the loaded languages with :meth:`translate`.
    """

    def __init__(self, workers_count: int = 1, cache_size: int = 0, log_level: str = 'off'):
        """
        Creates a Translator instance.

        :param workers_count: Number of workers which can be used at once.
        :param cache_size: Size of the cache used in bergamot.
        :param log_level: Level of logs used in bergamot.
        """
        # Languages successfully loaded so far, in loading order.
        self.loaded_languages = []
        # Nested mapping: self._loaded_engines[source_lang][target_lang] -> Engine.
        self._loaded_engines = {}
        config = ServiceConfig(numWorkers=workers_count, cacheSize=cache_size, logLevel=log_level)
        self.service = Service(config)

    def _load_model(self, model_name: str, download: bool = True) -> TranslationModel:
        """
        Loads a model by its name, downloading it first if it is not installed locally.

        :param model_name: Name of the model to load.
        :param download: If a model does not exist locally, if True, download it,
                         otherwise emit an error.
        :raises:
            ValueError: If the provided model does not exist.
            EnvironmentError: When a model is unavailable and download has been set to false.
        :return: Bergamot translation model instance.
        """
        if model_name not in Models.AVAILABLE:
            raise ValueError(f"Model {model_name} not available.")
        # Check if the model needs to be downloaded.
        if model_name not in Models.INSTALLED:
            if download:
                Models.download(model_name)
            else:
                langs = Models.get_model_languages(model_name)
                raise EnvironmentError(f"Translation model from {langs[0]} to {langs[1]} is not installed locally.")
        # Resolve the model's config path inside the repository it belongs to, then load it.
        model_path = REPOSITORY.modelConfigPath(Models.REPO_FOR_MODEL[model_name], model_name)
        return self.service.modelFromConfigPath(model_path)

    def _create_engine(self, source_lang: str, target_lang: str, download: bool = True) -> Engine:
        """
        Creates an Engine to translate a source lang to a target lang.

        A direct model is preferred. Otherwise the translation is chained through
        English (source -> en -> target), which is only possible when neither side
        is English.

        :param source_lang: Language to translate from.
        :param target_lang: Language to translate to.
        :param download: If a model does not exist locally, if True, download it,
                         otherwise emit an error.
        :raises:
            ValueError: If no direct model exists and a model between English and
                        one of the languages does not exist.
            EnvironmentError: When a model is unavailable and download has been set to false.
        :return: Engine instance.
        """
        direct_model_name = Models.get_model_name_for_languages(source_lang, target_lang)
        if direct_model_name is not None and (download or direct_model_name in Models.INSTALLED):
            # Direct model exists, and is installed locally if option download is disabled.
            return DirectBergamotModelEngine(
                source_lang, target_lang, self._load_model(direct_model_name, download), self.service
            )

        # Fall back to chained models with English as intermediary.
        pivot_lang = "en"
        if pivot_lang in (source_lang, target_lang):
            # One side already is English: there is nothing to pivot through, so do not
            # look up a meaningless en -> en model.
            model1 = model2 = None
            missing_lang = target_lang if source_lang == pivot_lang else source_lang
        else:
            model1 = Models.get_model_name_for_languages(source_lang, pivot_lang)
            model2 = Models.get_model_name_for_languages(pivot_lang, target_lang)
            missing_lang = source_lang if model1 is None else target_lang

        if model1 is None or model2 is None:
            if direct_model_name is not None:
                # A direct model exists but is not installed and download is disabled:
                # let _load_model raise the documented EnvironmentError rather than
                # reporting a misleading "missing model" ValueError.
                return DirectBergamotModelEngine(
                    source_lang, target_lang, self._load_model(direct_model_name, download), self.service
                )
            raise ValueError(f"Missing translation models between English and {missing_lang}.")

        return ChainBergamotModelsEngine(
            source_lang, target_lang,
            self._load_model(model1, download), self._load_model(model2, download),
            self.service
        )

    def load(self, lang: str, download: bool = True) -> None:
        """
        Loads a language code and all the associated models (for already added languages)
        into the translator.

        The operation is all-or-nothing: every engine between ``lang`` and the already
        loaded languages is created before any state is modified, so a failure leaves
        the translator exactly as it was.

        :param lang: Two-char ISO language name.
        :param download: If a model does not exist locally, if True, download it,
                         otherwise emit an error.
        :raises:
            ValueError: If the language does not exist, or if a model between a language
                        and English does not exist.
            EnvironmentError: When a model is unavailable and download has been set to false.
        """
        if lang not in Models.LANGS:
            raise ValueError(f"Language {lang} does not exist.")
        if lang in self.loaded_languages:
            warn(f"Language {lang} has already been imported.", RuntimeWarning)
            return

        # Build every engine first (a direct model or an English pivot is picked per pair
        # by _create_engine); any exception here propagates before state is touched.
        forward_engines = {}
        backward_engines = {}
        for other_lang in self.loaded_languages:
            forward_engines[other_lang] = self._create_engine(lang, other_lang, download)
            backward_engines[other_lang] = self._create_engine(other_lang, lang, download)

        # Commit: register the language and link it with the other loaded languages.
        self._loaded_engines[lang] = forward_engines
        for other_lang, engine in backward_engines.items():
            self._loaded_engines[other_lang][lang] = engine
        self.loaded_languages.append(lang)

    def translate(self, source_lang: str, target_lang: str, text: str,
                  html: bool = False, alignment: bool = False, quality_scores: bool = False) -> str:
        """
        Translates a text from a source lang to a target lang.

        :param source_lang: Language to translate from.
        :param target_lang: Language to translate to.
        :param text: Text to translate.
        :param html: Set to True if the text contains an HTML structure which needs to
                     be preserved while translated.
        :param alignment: Toggle for alignment.
        :param quality_scores: Toggle for whether to include the translation's quality scores
                               for each word in HTML format.
        :raises:
            ValueError: Either source_lang or target_lang haven't been loaded yet, or
                        both are the same language.
        :return: The translated text.
        """
        if source_lang not in self.loaded_languages:
            raise ValueError(f"Language {source_lang} is not loaded. Use the load() function first.")
        if target_lang not in self.loaded_languages:
            raise ValueError(f"Language {target_lang} is not loaded. Use the load() function first.")
        if source_lang == target_lang:
            # No engine is ever created from a language to itself; fail with a clear
            # error instead of a bare KeyError from the engine lookup below.
            raise ValueError(f"Cannot translate from {source_lang} to {target_lang}: source and target languages are identical.")
        return self._loaded_engines[source_lang][target_lang].translate(text, html, alignment, quality_scores)