## CodeT Test Generation Datasets

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb)

This notebook downloads the CodeT test-case-generation datasets (HumanEval and MBPP). It extracts each sample's docstring as a natural-language request and its test code as the solution, then writes the resulting `(prompt, solution)` pairs to `.jsonl` files.

Requirements: `requests`

In [1]:
import json
from pathlib import Path
import requests
from typing import List, Tuple

In [2]:
# Raw CodeT test-generation datasets to fetch from the microsoft/CodeT repo.
DATA_FILES: List[str] = [
    "HumanEval_for_test_case_generation.jsonl",
    "mbpp_sanitized_for_test_case_generation.jsonl",
]

# Output names; paired positionally with DATA_FILES.
OUT_FILES: List[str] = [
    "HumanEval_testgen.jsonl",
    "mbpp_testgen.jsonl",
]

# Where the downloaded raw datasets are stored locally.
FILE_PATHS: List[Path] = [Path("data") / name for name in DATA_FILES]

# Where the augmented (prompt, solution) JSONL files are written.
OUT_PATHS: List[Path] = [Path("data", "augmented", name) for name in OUT_FILES]

In [3]:
def download_file(filename: str, timeout: float = 30.0):
    """Download one CodeT dataset file into the local ``data/`` directory.

    Args:
        filename: Name of the file under ``CodeT/data/dataset`` in the
            microsoft/CodeT GitHub repository; saved as ``data/<filename>``.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved as if it were dataset content.
    """
    url = f"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}"
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    out_path = Path("data") / filename
    # Create data/ if missing; writing would otherwise fail on a fresh checkout.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_bytes(response.content)


for filename in DATA_FILES:
    download_file(filename)

In [4]:
def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:
    """Locate the first docstring delimited by lines that begin with triple quotes.

    A line counts as a delimiter only if its stripped text *starts* with a
    triple double quote or a triple single quote. A closing delimiter that
    follows text on the same line is therefore not detected.

    Args:
        prompt_lines: The prompt, already split into lines.

    Returns:
        ``(start, end)``: the indices of the opening and closing delimiter lines.

    Raises:
        ValueError: If fewer than two delimiter lines are found.
    """
    quote_prefixes = ('"""', "'''")
    docstring_start = None

    for i, line in enumerate(prompt_lines):
        if not line.strip().startswith(quote_prefixes):
            continue
        # Compare against None: a docstring opening on line 0 is a valid start.
        if docstring_start is not None:
            return docstring_start, i
        docstring_start = i

    raise ValueError(f"No complete docstring found!\n{prompt_lines}")


def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:
    """Return a new list of the lines from index ``start`` up to, but not including, ``end``.

    The result is a copy, so callers may mutate it without touching ``prompt_lines``.
    """
    return prompt_lines[slice(start, end)]

In [5]:
def get_request(sample: dict) -> str:
    """Build a natural-language test-writing request from a sample's docstring.

    Args:
        sample: A dataset record; only its ``"prompt"`` field is read.

    Returns:
        A fixed instruction prefix followed by the docstring text. The text
        runs from the opening delimiter line up to, but not including, the
        closing delimiter line. Quote marks and any ``...`` are removed from
        the first line. Every line is stripped, blank lines are dropped, and
        the remaining lines are joined with single spaces.

    Raises:
        ValueError: Propagated from ``get_docstring_indices`` when the prompt
            contains no complete docstring.
    """
    prompt_lines = sample["prompt"].splitlines()
    docstring_start, docstring_end = get_docstring_indices(prompt_lines)

    in_docstring = get_between(prompt_lines, docstring_start, docstring_end)
    # The opening line carries the delimiter; remove either quote style so it
    # does not leak into the request, along with any "..." on that line.
    in_docstring[0] = (
        in_docstring[0].replace('"""', "").replace("'''", "").replace("...", "").strip()
    )

    parts = [p.strip() for p in in_docstring]
    # Skip blank lines so they do not turn into doubled spaces.
    request = "Write a test for a Python function with the following docstring: " + " ".join(
        p for p in parts if p
    )
    return request


def get_test_code(sample: dict) -> str:
    """Return the test-harness portion of a sample's ``"test"`` field.

    The result runs from the last line containing ``def check(`` to the end of
    the test text. If no such line exists, the whole test text is returned.
    Lines are re-joined with newline characters, so any trailing newline in
    the original text is not kept.

    Args:
        sample: A dataset record; only its ``"test"`` field is read.
    """
    test_lines = sample["test"].splitlines()
    start = 0
    for i, line in enumerate(test_lines):
        # No break: if several lines match, the last one wins.
        if "def check(" in line:
            start = i
    return "\n".join(test_lines[start:])

In [6]:
def process_file(file_path: Path, out_path: Path):
    """Convert one CodeT JSONL dataset into ``{"prompt", "solution"}`` JSONL records.

    Each input sample becomes one record. ``"prompt"`` is the docstring-based
    request from ``get_request`` and ``"solution"`` is the test code from
    ``get_test_code``.

    Args:
        file_path: Input JSONL file with one JSON object per line. Blank lines
            are skipped.
        out_path: Destination JSONL file. Its parent directory is created if
            missing, and an existing file is overwritten.
    """
    samples = [
        json.loads(line)
        for line in file_path.read_text(encoding="utf-8").splitlines()
        # json.loads rejects empty strings, so skip blank lines.
        if line.strip()
    ]

    # Build every record before opening the output, so a parse failure does
    # not leave a truncated file behind.
    output = [
        {"prompt": get_request(sample), "solution": get_test_code(sample)}
        for sample in samples
    ]

    out_path = Path(out_path)
    # open() fails if the output directory (e.g. data/augmented) does not exist.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in output)

In [7]:
# Convert each downloaded dataset into its positionally paired output file.
for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):
    process_file(file_path=file_path, out_path=out_path)