In [None]:
# Fraction of sentence pairs expected to land in the validation split.
VALIDATION_SIZE: float = 0.3
# Maximum number of token vectors kept per sentence; longer sequences are truncated.
MAX_TOKENS: int = 4096
# Average sentence length is roughly 15~20 tokens in English and 8~14 in Chinese.
# NOTE(review): not referenced in the visible part of this notebook.
CHUNK_SIZE: int = 16
# NOTE(review): not referenced in the visible part of this notebook.
LATENT_SIZE: int = 512

import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

import tensorflow_text as tf_text
import pandas as pd

from tqdm import tqdm

# Load the parallel corpus; its en_news and zh_news columns are used below.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
df = pd.read_pickle('merged.pkl')

In [None]:
def _wrap_with_markers(sentences):
    """Return an object array with every sentence framed as "[START] ... [END]"."""
    return np.array(["[START] " + sentence + " [END]" for sentence in sentences], dtype=object)


# Frame every sentence with explicit start/end markers. For Chinese, first strip
# the 'ï¿½' sequence (which looks like a mis-decoded U+FFFD replacement character).
en_examples = _wrap_with_markers(df.en_news)
zh_examples = _wrap_with_markers(df.zh_news.str.replace('ï¿½', '', regex=False))

# Reproducible random split: a row goes to training when its uniform draw
# exceeds VALIDATION_SIZE, otherwise to validation.
np.random.seed(42)
isTrain = np.random.rand(df.shape[0]) > VALIDATION_SIZE
en_train, en_valid = en_examples[isTrain], en_examples[~isTrain]
zh_train, zh_valid = zh_examples[isTrain], zh_examples[~isTrain]

# Each dataset element is a (Chinese, English) string pair.
examples = {
    'train': tf.data.Dataset.from_tensor_slices((zh_train, en_train)),
    'validation': tf.data.Dataset.from_tensor_slices((zh_valid, en_valid)),
}
train_examples, val_examples = examples['train'], examples['validation']

import spacy
import pickle
from scipy.spatial import KDTree
from tqdm.notebook import tqdm
import fasttext
import fasttext.util

class Tokenizer(object):
    """spaCy tokenizer paired with a fastText model and a vector->text vocabulary.

    ``train`` records the fastText vector of every token piece seen in a corpus
    (``vec_to_text``) and builds a KDTree over those vectors (``tree``) so that
    vectors can later be mapped back to text by nearest-neighbour lookup.
    """

    def __init__(self, lang):
        """Load the spaCy pipeline and fastText model for ``lang``.

        Args:
            lang: 'en' or 'zh'.

        Raises:
            ValueError: for any other language, instead of returning an object
                without ``lang`` / ``fasttext_vectorizer`` that fails later.
        """
        if lang == 'zh':
            self.tokenizer = spacy.load("zh_core_web_sm")
            # Register the sequence markers with the pkuseg segmenter's user dictionary.
            self.tokenizer.tokenizer.pkuseg_update_user_dict(["[START]", "[END]"])
            self.fasttext_vectorizer = fasttext.load_model('cc.zh.300.bin')
        elif lang == 'en':
            self.tokenizer = spacy.load("en_core_web_sm")
            # Register the sequence markers as special cases so they stay single tokens.
            for marker in ("[START]", "[END]"):
                self.tokenizer.tokenizer.add_special_case(marker, [{spacy.attrs.ORTH: marker}])
            self.fasttext_vectorizer = fasttext.load_model('cc.en.300.bin')
        else:
            raise ValueError(f"Unsupported language {lang!r}; expected 'en' or 'zh'")
        self.lang = lang
        self.vec_to_text = {}
        self.tree = None

    def _token_pieces(self, token):
        """Return the strings whose vectors represent ``token``.

        Digit tokens are split into single characters; other tokens are
        lower-cased for English and kept verbatim for Chinese.
        """
        if token.is_digit:
            return list(token.text)
        if self.lang == 'en':
            return [token.text.lower()]
        return [token.text]

    def train(self, docs):
        """Learn a vector->text vocabulary from ``docs`` and index it with a KDTree.

        When several texts share a vector, the first one seen is kept.

        Args:
            docs: iterable of strings in this tokenizer's language.
        """
        for doc in tqdm(docs):
            for token in self.tokenizer(doc):
                for piece in self._token_pieces(token):
                    # Compute the vector once and use it directly as the dictionary key.
                    key = tuple(self.fasttext_vectorizer.get_word_vector(piece))
                    self.vec_to_text.setdefault(key, piece)

        print()
        print('Train summary:')
        print(f'\t{len(self.vec_to_text)} {self.lang} words learned')
        print('Building KDTree...')
        self.tree = KDTree(list(self.vec_to_text.keys()))
        print('KDTree built')

    def tokenize(self, sentence):
        """Print the tokens of ``sentence`` separated by tabs (debugging aid)."""
        for token in self.tokenizer(sentence):
            print(token.text, end='\t')



# One tokenizer/vectorizer per language (English is constructed first).
en_tok_obj, zh_tok_obj = Tokenizer('en'), Tokenizer('zh')




In [None]:
# Reuse the English fastText model already loaded by en_tok_obj instead of
# loading 'cc.en.300.bin' into memory a second time.
obj = en_tok_obj.fasttext_vectorizer

In [None]:
# Learn the English vector->text vocabulary from all sentences (both splits),
# then cache it so a later session can reload it instead of retraining.
en_tok_obj.train(en_examples)
with open('en_vec_to_text.300-split_digits.pkl', mode='wb') as f:
    pickle.dump(en_tok_obj.vec_to_text, f)


In [None]:
# Learn the Chinese vector->text vocabulary from all sentences (both splits),
# then cache it so a later session can reload it instead of retraining.
zh_tok_obj.train(zh_examples)
with open('zh_vec_to_text.300-split_digits.pkl', mode='wb') as f:
    pickle.dump(zh_tok_obj.vec_to_text, f)

In [None]:
# Reload the cached Chinese vocabulary and rebuild the KDTree over its vectors.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open('zh_vec_to_text.300-split_digits.pkl', mode='rb') as f:
    zh_tok_obj.vec_to_text = pickle.load(f)
zh_tok_obj.tree = KDTree(list(zh_tok_obj.vec_to_text))

In [None]:
# Reload the cached English vocabulary and rebuild the KDTree over its vectors.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open('en_vec_to_text.300-split_digits.pkl', mode='rb') as f:
    en_tok_obj.vec_to_text = pickle.load(f)
en_tok_obj.tree = KDTree(list(en_tok_obj.vec_to_text))

In [None]:
def en_vectorize(docs):
    """Convert English documents to per-token fastText vectors.

    Uses the spaCy pipeline and fastText model held by ``en_tok_obj``. A digit
    token is expanded into one vector per character; any other token maps to
    the vector of its lower-cased text, mirroring ``Tokenizer.train``.

    Args:
        docs: iterable of English strings.

    Returns:
        A list with one entry per document, each a list of vectors in token order.
    """
    vectorized_docs = []
    for doc in docs:
        pieces = []
        for token in en_tok_obj.tokenizer(doc):
            if token.is_digit:
                pieces.extend(token.text)
            else:
                pieces.append(token.text.lower())
        vectorized_docs.append(
            [en_tok_obj.fasttext_vectorizer.get_word_vector(piece) for piece in pieces])
    return vectorized_docs

In [None]:
def zh_vectorize(docs):
    """Convert Chinese documents to per-token fastText vectors.

    Uses the spaCy pipeline and fastText model held by ``zh_tok_obj``. A digit
    token is expanded into one vector per character; any other token maps to
    the vector of its text as-is, mirroring ``Tokenizer.train``.

    Args:
        docs: iterable of Chinese strings.

    Returns:
        A list with one entry per document, each a list of vectors in token order.
    """
    def _pieces(token):
        # Digit runs become single characters; other tokens are kept verbatim.
        return list(token.text) if token.is_digit else [token.text]

    return [
        [zh_tok_obj.fasttext_vectorizer.get_word_vector(piece)
         for token in zh_tok_obj.tokenizer(doc)
         for piece in _pieces(token)]
        for doc in docs
    ]

In [None]:
def tokenize_pairs(zh, en):
    """Vectorise a batch of (Chinese, English) sentence pairs.

    Meant to run inside ``tf.numpy_function``: ``zh`` and ``en`` arrive as numpy
    arrays of UTF-8 encoded bytes, and numpy arrays are returned.

    Returns:
        zh: float16 array (batch, zh_len, dim), zero-padded, at most MAX_TOKENS steps.
        en_inputs: English vectors with the last step dropped, zero-padded.
        en_labels: English vectors with the first step dropped, zero-padded.
    """
    zh_vectors = zh_vectorize([sentence.decode('utf-8') for sentence in zh.tolist()])
    en_vectors = en_vectorize([sentence.decode('utf-8') for sentence in en.tolist()])

    # ragged_rank=1: only the token axis varies per sentence. The embedding axis
    # is uniform, so keep it dense instead of letting it be inferred as ragged too.
    zh_ragged = tf.ragged.constant(zh_vectors, dtype=tf.float16, ragged_rank=1)[:, :MAX_TOKENS, :]
    en_ragged = tf.ragged.constant(en_vectors, dtype=tf.float16, ragged_rank=1)[:, :MAX_TOKENS, :]

    # Shift by one step: inputs drop the last vector, labels drop the first.
    en_inputs = en_ragged[:, :-1, :].to_tensor()
    en_labels = en_ragged[:, 1:, :].to_tensor()
    return zh_ragged.to_tensor().numpy(), en_inputs.numpy(), en_labels.numpy()

def py_wrapper_func_star(zh, en):
    """Run ``tokenize_pairs`` as a numpy function inside a tf.data pipeline.

    Returns:
        ``((zh_vectors, en_inputs), en_labels)``, all float16 tensors.
    """
    outputs = tf.numpy_function(tokenize_pairs, [zh, en], (tf.float16, tf.float16, tf.float16))
    zh_vectors, en_inputs, en_labels = outputs
    return (zh_vectors, en_inputs), en_labels

# Number of sentence pairs per batch produced by make_batches.
BATCH_SIZE: int = 4

def set_shapes(zh, en_inputs, en_labels):
    """Attach the static shapes that ``tf.numpy_function`` outputs lack.

    All three tensors are declared rank 3 with a trailing dimension of 300
    (matching the cc.*.300.bin fastText models); batch and sequence lengths
    stay unknown.

    Returns:
        ``((zh, en_inputs), en_labels)``.
    """
    feature_shape = [None, None, 300]
    zh = tf.ensure_shape(zh, feature_shape)
    en_inputs = tf.ensure_shape(en_inputs, feature_shape)
    en_labels = tf.ensure_shape(en_labels, feature_shape)
    return (zh, en_inputs), en_labels

def make_batches(ds):
    """Batch raw (zh, en) string pairs and convert each batch to vectors.

    Returns:
        A prefetched dataset of ``((zh, en_inputs), en_labels)`` float16 batches.
    """
    batched = ds.batch(BATCH_SIZE)
    vectorized = batched.map(py_wrapper_func_star, tf.data.AUTOTUNE)
    # Restore static shapes, which tf.numpy_function cannot provide.
    shaped = vectorized.map(
        lambda model_inputs, labels: set_shapes(model_inputs[0], model_inputs[1], labels))
    return shaped.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
from datetime import datetime

In [None]:
# Vectorise the training split and save it to disk with GZIP compression,
# printing wall-clock timestamps before and after.
print(datetime.now())
train_batches = make_batches(train_examples)
train_batches.save(path='ZH_EN-train_batch-300-split_digits-new', compression='GZIP')
print(datetime.now())

In [None]:
# Vectorise the validation split and save it to disk with GZIP compression,
# printing wall-clock timestamps before and after.
print(datetime.now())
val_batches = make_batches(val_examples)
val_batches.save(path='ZH_EN-val_batch-300-split_digits-new', compression='GZIP')
print(datetime.now())

In [None]:
import pickle
# Build the KDTree from the split-digits vocabulary written earlier in this notebook.
# The batches were vectorised with digit splitting, and the string lookup table
# below is (per its file name) the split-digit variant. Indexing a different
# vocabulary can return neighbours that are missing from that table (KeyError).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('zh_vec_to_text.300-split_digits.pkl', 'rb') as f:
  zh_vec_to_text  = pickle.load(f)
from scipy.spatial import KDTree
zh_tree = KDTree(list(zh_vec_to_text.keys()))
# Presumably maps a formatted vector string to its text -- confirm against its producer.
with open('zh_str_dict.300-split_digit.pkl', 'rb') as f:
  zh_vec_str_dict = pickle.load(f)

def zh_devectorize(vectors, tok_obj=None):
    """Map each vector to its nearest Chinese vocabulary entry and join the texts.

    Args:
        vectors: iterable of vectors, e.g. the rows of one sequence in a batch.
        tok_obj: object exposing ``tree`` (a KDTree) and ``vec_to_text`` (the dict
            whose keys the tree was built from). Defaults to ``zh_tok_obj``, the
            split-digits vocabulary used to build the batches.

    Returns:
        The nearest texts concatenated without separators ('' for no vectors).
    """
    if tok_obj is None:
        tok_obj = zh_tok_obj
    # KDTree.query returns row indices into the data the tree was built from, which
    # was list(vec_to_text.keys()), so the same index selects the matching text. This
    # avoids round-tripping through a formatted float string key.
    texts = list(tok_obj.vec_to_text.values())
    return ''.join(texts[tok_obj.tree.query(vector)[1]] for vector in vectors)

In [None]:
def en_devectorize(vectors, tok_obj=None):
    """Map each vector to its nearest English vocabulary entry and join with spaces.

    Args:
        vectors: iterable of vectors, e.g. the rows of one sequence in a batch.
        tok_obj: object exposing ``tree`` (a KDTree) and ``vec_to_text`` (the dict
            whose keys the tree was built from). Defaults to ``en_tok_obj``; the
            English vocabulary must be used here, not the Chinese one.

    Returns:
        The nearest texts joined by single spaces ('' for no vectors).
    """
    if tok_obj is None:
        tok_obj = en_tok_obj
    # KDTree.query returns row indices into the data the tree was built from, which
    # was list(vec_to_text.keys()), so the same index selects the matching text.
    texts = list(tok_obj.vec_to_text.values())
    return ' '.join(texts[tok_obj.tree.query(vector)[1]] for vector in vectors)

In [None]:
# Take the first validation batch for inspection. next() raises StopIteration on
# an empty dataset instead of silently leaving zh/en unbound or stale.
(zh, en), _ = next(iter(val_batches.take(1)))

In [None]:
# Static shape of the sampled Chinese batch: (batch, tokens, vector size).
zh.get_shape()