62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
from pytest import raises, warns
|
|
import warnings
|
|
from bergamot import TranslationModel
|
|
from pybergamot import Translator
|
|
from pybergamot.engine import DirectBergamotModelEngine, ChainBergamotModelsEngine
|
|
|
|
|
|
def test_init():
    """A ``Translator`` can be built with no arguments without raising."""
    # Success criterion is simply that construction does not raise.
    _ = Translator()
|
|
|
|
|
|
def test__load_model():
    """Exercise ``Translator._load_model`` failure modes and its happy path.

    * An unknown model identifier is rejected with ``ValueError``.
    * Loading ``"ukr-fin-tiny"`` with ``download=False`` fails with
      ``OSError`` (``EnvironmentError`` is the same class).
    * Loading ``"en-fr-tiny"`` returns a ``bergamot.TranslationModel``.
    """
    translator = Translator()

    # Unknown model name.
    with raises(ValueError):
        translator._load_model("inexistant")

    # Downloads disabled: presumably the model is not available locally in
    # the test environment, so loading must fail instead of fetching it.
    with raises(OSError):
        translator._load_model("ukr-fin-tiny", download=False)

    loaded = translator._load_model("en-fr-tiny")
    assert isinstance(loaded, TranslationModel)
|
|
|
|
|
|
def test__create_engine():
    """Exercise ``Translator._create_engine`` on invalid and valid pairs.

    An unknown source or target language raises ``ValueError``. The
    ``en -> fr`` pair must yield a ``DirectBergamotModelEngine`` and the
    ``fr -> es`` pair must yield a ``ChainBergamotModelsEngine``.
    """
    translator = Translator()

    # Either side of the pair being unknown is an error.
    for source, target in (("inexistant", "en"), ("en", "inexistant")):
        with raises(ValueError):
            translator._create_engine(source, target)

    direct_engine = translator._create_engine("en", "fr")
    assert isinstance(direct_engine, DirectBergamotModelEngine)

    chained_engine = translator._create_engine("fr", "es")
    assert isinstance(chained_engine, ChainBergamotModelsEngine)
|
|
|
|
|
|
def test_load():
    """Check the bookkeeping performed by ``Translator.load``.

    Covers rejecting an unknown language, warning on a duplicate load, the
    contents and order of ``loaded_languages``, and how many engines are
    stored per language in ``_loaded_engines`` as languages are added.
    """
    translator = Translator()

    with raises(ValueError):
        translator.load("inexistant")

    translator.load("es")

    # Loading an already-loaded language emits a RuntimeWarning...
    with warns(RuntimeWarning):
        translator.load("es")

    # ...and must not register it twice. The whole sequence is compared
    # (not just its length and last entry) so that a regression dropping
    # or reordering previously loaded languages is also caught.
    assert list(translator.loaded_languages) == ["es"]

    # Observed counts below equal (number of loaded languages - 1) for every
    # language; presumably one engine per other loaded language -- confirm
    # against Translator.load.
    assert len(translator._loaded_engines['es']) == 0

    translator.load("en")
    assert list(translator.loaded_languages) == ["es", "en"]
    assert len(translator._loaded_engines['es']) == 1
    assert len(translator._loaded_engines['en']) == 1

    translator.load("fr")
    assert list(translator.loaded_languages) == ["es", "en", "fr"]
    assert len(translator._loaded_engines['es']) == 2
    assert len(translator._loaded_engines['en']) == 2
    assert len(translator._loaded_engines['fr']) == 2
|
|
|
|
|
|
def test_translate():
    """Check that ``Translator.translate`` needs both languages loaded.

    Translating from or to a language that has not been loaded raises
    ``ValueError``. Once both ends of the pair are loaded, ``translate``
    returns a string.
    """
    translator = Translator()
    translator.load("en")
    translator.load("fr")

    # "es" is not loaded yet, so neither direction involving it is allowed.
    with raises(ValueError):
        translator.translate("fr", "es", "Salut !")
    with raises(ValueError):
        translator.translate("es", "fr", "¡Buenos días!")

    translator.load("es")

    # isinstance is the idiomatic type check (``type(x) == str`` is an
    # anti-pattern that also rejects str subclasses).
    assert isinstance(translator.translate("en", "fr", "Hello!"), str)
    assert isinstance(translator.translate("es", "fr", "¡Buenos días!"), str)
|