pybergamot/tests/test_translator.py
2023-11-03 00:36:46 +01:00

62 lines
2.2 KiB
Python

from pytest import raises, warns
import warnings
from bergamot import TranslationModel
from pybergamot import Translator
from pybergamot.engine import DirectBergamotModelEngine, ChainBergamotModelsEngine
def test_init():
    """Smoke test: a ``Translator`` can be constructed with no arguments."""
    # The test passes as long as the constructor does not raise.
    Translator()
def test__load_model():
    """Check ``Translator._load_model`` error handling and return type.

    Covers three cases: an unknown model name, a model requested with
    ``download=False``, and a model that is expected to load successfully.
    """
    tr = Translator()

    # A model name the translator does not recognise is rejected.
    with raises(ValueError):
        tr._load_model("inexistant")

    # With downloading disabled this model must fail to load
    # (presumably because it is not available locally -- TODO confirm).
    with raises(EnvironmentError):
        tr._load_model("ukr-fin-tiny", download=False)

    # A successful load hands back a bergamot ``TranslationModel``.
    model = tr._load_model("en-fr-tiny")
    assert isinstance(model, TranslationModel)
def test__create_engine():
    """Check ``Translator._create_engine`` for invalid and valid language pairs.

    An unknown language on either side must raise ``ValueError``; valid pairs
    must yield the engine class expected for that pair.
    """
    tr = Translator()

    # The unknown language is rejected whether it is the source or the target.
    invalid_pairs = (("inexistant", "en"), ("en", "inexistant"))
    for source, target in invalid_pairs:
        with raises(ValueError):
            tr._create_engine(source, target)

    # en -> fr is expected to be a direct engine, fr -> es a chained one
    # (presumably pivoting through an intermediate language -- verify).
    expected_engines = (
        ("en", "fr", DirectBergamotModelEngine),
        ("fr", "es", ChainBergamotModelsEngine),
    )
    for source, target, engine_cls in expected_engines:
        engine = tr._create_engine(source, target)
        assert isinstance(engine, engine_cls)
def test_load():
    """Check ``Translator.load`` bookkeeping as languages are added one by one.

    Verifies that an unknown language raises ``ValueError``, that re-loading an
    already loaded language emits a ``RuntimeWarning`` without adding a second
    entry, and that ``loaded_languages`` / ``_loaded_engines`` grow as expected
    after each successful load.
    """
    tr = Translator()

    with raises(ValueError):
        tr.load("inexistant")

    # First language: loading it twice must only warn, not duplicate it.
    tr.load("es")
    with warns(RuntimeWarning):
        tr.load("es")
    assert len(tr.loaded_languages) == 1
    assert tr.loaded_languages[0] == "es"
    assert len(tr._loaded_engines["es"]) == 0

    # Second language: every loaded language now has exactly one engine.
    tr.load("en")
    assert len(tr.loaded_languages) == 2
    assert tr.loaded_languages[1] == "en"
    for code in ("es", "en"):
        assert len(tr._loaded_engines[code]) == 1

    # Third language: every loaded language now has exactly two engines.
    tr.load("fr")
    assert len(tr.loaded_languages) == 3
    assert tr.loaded_languages[2] == "fr"
    for code in ("es", "en", "fr"):
        assert len(tr._loaded_engines[code]) == 2
def test_translate():
    """Check ``Translator.translate`` before and after a language is loaded.

    With only "en" and "fr" loaded, any translation involving "es" (as source
    or as target) must raise ``ValueError``. Once "es" is loaded as well,
    translations must return a string.
    """
    translator = Translator()
    translator.load("en")
    translator.load("fr")

    # "es" has not been loaded yet, so it is rejected on either side.
    with raises(ValueError):
        translator.translate("fr", "es", "Salut !")
    with raises(ValueError):
        translator.translate("es", "fr", "¡Buenos días!")

    translator.load("es")

    # isinstance() is the idiomatic type check; comparing type(x) == str
    # would reject a str subclass for no good reason.
    assert isinstance(translator.translate("en", "fr", "Hello!"), str)
    assert isinstance(translator.translate("es", "fr", "¡Buenos días!"), str)