// Copyright 2018 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #include "sentencepiece_trainer.h" #include #include #include "builder.h" #include "common.h" #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" #include "spec_parser.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_split.h" #include "third_party/absl/strings/string_view.h" #include "third_party/absl/strings/strip.h" #include "trainer_factory.h" #include "util.h" namespace sentencepiece { namespace { static constexpr char kDefaultNormalizerName[] = "nmt_nfkc"; } // namespace // static util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { NormalizerSpec normalizer_spec; return Train(trainer_spec, normalizer_spec, sentence_iterator, serialized_model_proto); } util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { NormalizerSpec denormalizer_spec; return Train(trainer_spec, normalizer_spec, denormalizer_spec, sentence_iterator, serialized_model_proto); } // static util::Status SentencePieceTrainer::Train( const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, const NormalizerSpec &denormalizer_spec, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { auto copied_normalizer_spec = normalizer_spec; RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_normalizer_spec, false)); auto copied_denormalizer_spec = denormalizer_spec; RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_denormalizer_spec, true)); auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec, copied_denormalizer_spec); std::string info = absl::StrCat(PrintProto(trainer_spec, "trainer_spec"), PrintProto(copied_normalizer_spec, "normalizer_spec")); if (!copied_denormalizer_spec.precompiled_charsmap().empty()) { info += PrintProto(copied_denormalizer_spec, "denormalizer_spec"); } else { info += "denormalizer_spec {}"; } LOG(INFO) << "Starts training with : \n" << info; if (serialized_model_proto) { ModelProto model_proto; RETURN_IF_ERROR(trainer->Train(sentence_iterator, &model_proto)); *serialized_model_proto = model_proto.SerializeAsString(); } else { RETURN_IF_ERROR(trainer->Train(sentence_iterator, nullptr)); } return util::OkStatus(); } // static NormalizerSpec SentencePieceTrainer::GetNormalizerSpec(absl::string_view name) { NormalizerSpec spec; spec.set_name(name.data(), name.size()); CHECK_OK(normalizer::Builder::GetPrecompiledCharsMap( spec.name(), spec.mutable_precompiled_charsmap())); return spec; } // static util::Status SentencePieceTrainer::MergeSpecsFromArgs( absl::string_view args, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) { CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null."; CHECK_OR_RETURN(normalizer_spec) << "`normalizer_spec` must not be null."; CHECK_OR_RETURN(denormalizer_spec) << "`denormalizer_spec` must not be null."; if (args.empty()) return util::OkStatus(); std::unordered_map kwargs; for (auto arg : absl::StrSplit(args, " ")) { absl::ConsumePrefix(&arg, "--"); std::string key, value; const auto pos = arg.find('='); if (pos == absl::string_view::npos) { key = std::string(arg); } else { key = std::string(arg.substr(0, pos)); value = std::string(arg.substr(pos + 1)); } kwargs.emplace(key, value); } return MergeSpecsFromArgs(kwargs, trainer_spec, normalizer_spec, denormalizer_spec); } // static util::Status SentencePieceTrainer::MergeSpecsFromArgs( const std::unordered_map &kwargs, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) { CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null."; CHECK_OR_RETURN(normalizer_spec) << "`normalizer_spec` must not be null."; CHECK_OR_RETURN(denormalizer_spec) << "`denormalizer_spec` must not be null."; for (const auto &it : kwargs) { const auto &key = it.first; const auto &value = it.second; // Exceptions. if (key == "normalization_rule_name") { normalizer_spec->set_name(value); continue; } else if (key == "denormalization_rule_tsv") { denormalizer_spec->set_normalization_rule_tsv(value); denormalizer_spec->set_add_dummy_prefix(false); denormalizer_spec->set_remove_extra_whitespaces(false); denormalizer_spec->set_escape_whitespaces(false); continue; } else if (key == "minloglevel") { int v = 0; CHECK_OR_RETURN(absl::SimpleAtoi(value, &v)); logging::SetMinLogLevel(v); continue; } const auto status_train = SetProtoField(key, value, trainer_spec); if (status_train.ok()) continue; if (!util::IsNotFound(status_train)) return status_train; const auto status_norm = SetProtoField(key, value, normalizer_spec); if (status_norm.ok()) continue; if (!util::IsNotFound(status_norm)) return status_norm; // Not found both in trainer_spec and normalizer_spec. if (util::IsNotFound(status_train) && util::IsNotFound(status_norm)) { return status_train; } } return util::OkStatus(); } // static util::Status SentencePieceTrainer::Train(absl::string_view args, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { LOG(INFO) << "Running command: " << args.data(); TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; RETURN_IF_ERROR(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec, &denormalizer_spec)); return Train(trainer_spec, normalizer_spec, denormalizer_spec, sentence_iterator, serialized_model_proto); } // static util::Status SentencePieceTrainer::Train( const std::unordered_map &kwargs, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; NormalizerSpec denormalizer_spec; RETURN_IF_ERROR(MergeSpecsFromArgs(kwargs, &trainer_spec, &normalizer_spec, &denormalizer_spec)); return Train(trainer_spec, normalizer_spec, denormalizer_spec, sentence_iterator, serialized_model_proto); } namespace { class VectorSentenceIterator : public SentenceIterator { public: explicit VectorSentenceIterator(const std::vector &values) : iter_(values.begin()), end_(values.end()) {} virtual ~VectorSentenceIterator() {} virtual bool done() const { return iter_ == end_; } void Next() override { ++iter_; } const std::string &value() const override { return *iter_; } util::Status status() const override { return util::OkStatus(); } private: std::vector::const_iterator iter_; std::vector::const_iterator end_; }; } // namespace // static util::Status SentencePieceTrainer::Train( absl::string_view args, const std::vector &sentences, std::string *serialized_model_proto) { VectorSentenceIterator iter(sentences); return Train(args, &iter, serialized_model_proto); } // static util::Status SentencePieceTrainer::Train( const std::unordered_map &kwargs, const std::vector &sentences, std::string *serialized_model_proto) { VectorSentenceIterator iter(sentences); return Train(kwargs, &iter, serialized_model_proto); } // static util::Status SentencePieceTrainer::PopulateNormalizerSpec( NormalizerSpec *normalizer_spec, bool is_denormalizer) { CHECK_OR_RETURN(normalizer_spec); if (!normalizer_spec->normalization_rule_tsv().empty()) { CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty()) << "precompiled_charsmap is already defined."; normalizer::Builder::CharsMap chars_map; RETURN_IF_ERROR(normalizer::Builder::LoadCharsMap( normalizer_spec->normalization_rule_tsv(), &chars_map)); RETURN_IF_ERROR(normalizer::Builder::CompileCharsMap( chars_map, normalizer_spec->mutable_precompiled_charsmap())); normalizer_spec->set_name("user_defined"); } else if (!is_denormalizer) { if (normalizer_spec->name().empty()) { normalizer_spec->set_name(kDefaultNormalizerName); } if (normalizer_spec->precompiled_charsmap().empty()) { RETURN_IF_ERROR(normalizer::Builder::GetPrecompiledCharsMap( normalizer_spec->name(), normalizer_spec->mutable_precompiled_charsmap())); } } return util::OkStatus(); } // static util::Status SentencePieceTrainer::PopulateModelTypeFromString( absl::string_view type, TrainerSpec *spec) { static const std::unordered_map kModelTypeMap = {{"unigram", TrainerSpec::UNIGRAM}, {"bpe", TrainerSpec::BPE}, {"word", TrainerSpec::WORD}, {"char", TrainerSpec::CHAR}}; const auto it = kModelTypeMap.find(absl::AsciiStrToLower(type)); if (it != kModelTypeMap.end()) { spec->set_model_type(it->second); return util::OkStatus(); } return util::StatusBuilder(util::StatusCode::kInternal, GTL_LOC) << "\"" << type << "\" is not found in TrainerSpec"; } namespace { const pretokenizer::PretokenizerForTrainingInterface *g_pretokenizer = nullptr; } // namespace // static util::Status SentencePieceTrainer::SetPretokenizerForTraining( const pretokenizer::PretokenizerForTrainingInterface *pretokenizer) { g_pretokenizer = pretokenizer; return util::OkStatus(); } // static const pretokenizer::PretokenizerForTrainingInterface * SentencePieceTrainer::GetPretokenizerForTraining() { return g_pretokenizer; } SentencePieceNormalizer::SentencePieceNormalizer() {} SentencePieceNormalizer::~SentencePieceNormalizer() {} util::Status SentencePieceNormalizer::Load( std::unique_ptr model_proto) { model_proto_ = std::move(model_proto); normalizer_ = std::make_unique(model_proto_->normalizer_spec()); CHECK_OR_RETURN(normalizer_); return normalizer_->status(); } util::Status SentencePieceNormalizer::Load(absl::string_view filename) { auto model_proto = std::make_unique(); RETURN_IF_ERROR(io::LoadModelProto(filename, model_proto.get())); return Load(std::move(model_proto)); } util::Status SentencePieceNormalizer::LoadFromSerializedProto( absl::string_view serialized) { auto model_proto = std::make_unique(); CHECK_OR_RETURN( model_proto->ParseFromArray(serialized.data(), serialized.size())); return Load(std::move(model_proto)); } util::Status SentencePieceNormalizer::LoadFromRuleTSV( absl::string_view filename) { auto model_proto = std::make_unique(); auto *spec = model_proto->mutable_normalizer_spec(); spec->set_normalization_rule_tsv(std::string(filename)); RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(spec)); return Load(std::move(model_proto)); } util::Status SentencePieceNormalizer::LoadFromRuleName(absl::string_view name) { auto model_proto = std::make_unique(); auto *spec = model_proto->mutable_normalizer_spec(); spec->set_name(std::string(name)); RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(spec)); return Load(std::move(model_proto)); } util::Status SentencePieceNormalizer::Normalize(absl::string_view input, std::string *normalized) const { CHECK_OR_RETURN(normalizer_); std::vector norm_to_orig; return normalizer_->Normalize(input, normalized, &norm_to_orig); } util::Status SentencePieceNormalizer::Normalize( absl::string_view input, std::string *normalized, std::vector *norm_to_orig) const { CHECK_OR_RETURN(normalizer_); return normalizer_->Normalize(input, normalized, norm_to_orig); } std::string SentencePieceNormalizer::Normalize(absl::string_view input) const { std::string normalized; Normalize(input, &normalized).IgnoreError(); return normalized; } NormalizerSpec *SentencePieceNormalizer::mutable_normalizer_spec() const { return model_proto_ ? model_proto_->mutable_normalizer_spec() : nullptr; } std::string SentencePieceNormalizer::serialized_model_proto() const { return model_proto_ ? model_proto_->SerializeAsString() : ""; } void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm, std::vector *norm_to_orig) { auto utf8_to_unicode_offsets = [](absl::string_view str) { std::vector utf8_to_unicode(str.size() + 1, 0); size_t prev = 0; int ulen = 0; while (!str.empty()) { const size_t mblen = std::max(1, string_util::OneCharLen(str.data())); for (int i = prev; i < prev + mblen; ++i) { utf8_to_unicode[i] = ulen; } ++ulen; prev += mblen; str.remove_prefix(mblen); } utf8_to_unicode[prev] = ulen; return utf8_to_unicode; }; const auto orig_offsets = utf8_to_unicode_offsets(orig); const auto norm_offsets = utf8_to_unicode_offsets(norm); if (orig_offsets.empty() || norm_offsets.empty()) return; std::vector result(norm_offsets.back() + 1, 0); for (int i = 0; i < norm_to_orig->size(); ++i) { result[norm_offsets[i]] = orig_offsets[(*norm_to_orig)[i]]; } *norm_to_orig = std::move(result); } } // namespace sentencepiece