from argparse import Namespace

from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator


def test_all_datasets():
    """Load every listed dataset and index into the head of each split.

    For each dataset name, ``get_one_dataset`` is called with a config whose
    ``cache_dir`` is ``.cache``; the first ``max_samples`` items (or fewer if
    the split is shorter) of the train and eval splits are then accessed by
    index. The test passes when none of these calls raise.
    """
    # Upper bound on how many items per split are touched by the sanity check.
    max_samples = 1000

    translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"]
    others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"]
    dataset_names = translation + others + SUMMARIZATION_DATASETS + QA_DATASETS

    config = Namespace(cache_dir=".cache")
    for dataset_name in dataset_names:
        print(dataset_name)
        # Named *_split rather than train/eval so the builtin ``eval`` is not shadowed.
        train_split, eval_split = get_one_dataset(config, dataset_name)
        # Sanity check: item access must not raise. Indexing (instead of
        # iterating) is kept deliberately so only ``__len__``/``__getitem__``
        # are exercised, train split first and eval split second.
        for split in (train_split, eval_split):
            for idx in range(min(len(split), max_samples)):
                split[idx]


def test_collate_fn():
    """Check that ``DialogueDataCollator`` keeps batches within ``max_length``.

    All train splits are concatenated and batched through the collator, and
    likewise all eval splits. For every batch the second dimension of
    ``batch["targets"]`` (presumably the sequence length -- confirm against the
    collator) must not exceed ``max_length``. For train batches the decoded
    first input and its masked targets are also printed for manual inspection.
    """
    from torch.utils.data import ConcatDataset, DataLoader
    from utils import get_tokenizer

    # Single source of truth for the limit passed to the collator and asserted below.
    max_length = 512
    batch_size = 128

    config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
    tokenizer = get_tokenizer(config)
    collate_fn = DialogueDataCollator(tokenizer, max_length=max_length)

    others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
    train_splits, eval_splits = [], []
    for dataset_name in others + QA_DATASETS + SUMMARIZATION_DATASETS:
        print(dataset_name)
        # Named *_split so the builtin ``eval`` is not shadowed.
        train_split, eval_split = get_one_dataset(config, dataset_name)
        train_splits.append(train_split)
        eval_splits.append(eval_split)

    train_loader = DataLoader(ConcatDataset(train_splits), collate_fn=collate_fn, batch_size=batch_size)
    for batch in train_loader:
        print(batch.keys())
        print(tokenizer.decode(batch["input_ids"][0]))
        print("-----")
        # Decode only the target positions selected by the first sample's label mask.
        first_mask = batch["label_masks"][0]
        print(tokenizer.decode(batch["targets"][0][first_mask]))
        assert batch["targets"].shape[1] <= max_length

    eval_loader = DataLoader(ConcatDataset(eval_splits), collate_fn=collate_fn, batch_size=batch_size)
    for batch in eval_loader:
        assert batch["targets"].shape[1] <= max_length
