Migrate to gitea
scripts/collect_samples.py (Normal file, 152 lines)
@@ -0,0 +1,152 @@
import json
import pathlib
import logging
import sys

# ==============================================================================
# --- CONFIGURATION ---
# ==============================================================================

# --- Paths ---
# Try to determine the project root relative to this script's location
try:
    SCRIPT_DIR = pathlib.Path(__file__).parent
    ROOT_DIR = SCRIPT_DIR.parent
except NameError:
    SCRIPT_DIR = pathlib.Path.cwd()
    ROOT_DIR = SCRIPT_DIR.parent

# Input directory containing the source .jsonl files
RAW_DATA_DIR = ROOT_DIR / "raw_data"

# The pattern to match source files
FILE_PATTERN = "*raw-wiktextract-data.jsonl"

# Output directory for the collected samples
SAMPLES_DIR = ROOT_DIR / "samples"

# Final output filename
OUTPUT_FILENAME = "combined_samples.jsonl"

# --- Sampling Options ---

# How many matching entries to take from EACH source file.
SAMPLES_PER_FILE = 2

# Filter by language code.
# Leave as an empty set() to include ALL languages.
# Example: {"en", "de", "fr", "no"}
LANG_FILTER = set()

# Filter by part of speech.
# Leave as an empty set() to include ALL parts of speech.
# Example: {"noun", "verb", "adj"}
POS_FILTER = {"verb"}

# Only include entries in their own language (lang_code matches the file prefix)
OWN_LANG_FILTER = True

# ==============================================================================
# --- END OF CONFIGURATION ---
# ==============================================================================

# Set up simple logging to the console
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)


def collect_samples():
    # 1. Set up paths and directories
    input_dir = pathlib.Path(RAW_DATA_DIR)
    output_dir = pathlib.Path(SAMPLES_DIR)
    output_file = output_dir / OUTPUT_FILENAME

    if not input_dir.exists():
        logger.error(f"ERROR: Raw data directory not found at: {input_dir}")
        logger.error("Please ensure your configuration points to the correct folder.")
        sys.exit(1)

    # Create the samples directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    # 2. Find all matching input files
    source_files = list(input_dir.glob(FILE_PATTERN))
    if not source_files:
        logger.warning(f"No files matching '{FILE_PATTERN}' found in {input_dir}")
        sys.exit(0)

    logger.info(f"Found {len(source_files)} source files to sample from.")
    logger.info(f"Target: {SAMPLES_PER_FILE} samples per file.")
    logger.info(f"Language Filter: {LANG_FILTER if LANG_FILTER else 'ALL'}")
    logger.info(f"POS Filter: {POS_FILTER if POS_FILTER else 'ALL'}")
    logger.info(f"Own Language Filter: {'ENABLED' if OWN_LANG_FILTER else 'DISABLED'}")
    logger.info("-" * 50)

    total_collected = 0

    # Open the output file once and write samples from all inputs into it
    try:
        with open(output_file, 'w', encoding='utf-8') as out_f:

            for src_file in source_files:
                logger.info(f"Scanning: {src_file.name}...")
                lang_from_file = src_file.name[:2]
                file_collected = 0
                lines_read = 0

                try:
                    with open(src_file, 'r', encoding='utf-8') as in_f:
                        for line in in_f:
                            lines_read += 1

                            # Stop reading this file once we have enough samples
                            if file_collected >= SAMPLES_PER_FILE:
                                break

                            if not line.strip():
                                continue

                            try:
                                entry = json.loads(line)

                                # --- Filtering Logic ---
                                # 1. Language filter
                                if LANG_FILTER and entry.get('lang_code') not in LANG_FILTER:
                                    continue

                                # 2. POS filter
                                if POS_FILTER and entry.get('pos') not in POS_FILTER:
                                    continue

                                # 3. Own-language filter
                                if OWN_LANG_FILTER and entry.get('lang_code') != lang_from_file:
                                    continue

                                # --- If it passed the filters, save it ---
                                # Write the entry exactly as it appears in the source
                                json.dump(entry, out_f, ensure_ascii=False)
                                out_f.write('\n')
                                file_collected += 1
                                total_collected += 1

                            except json.JSONDecodeError:
                                # Ignore bad lines in source files during sampling
                                continue

                    logger.info(f" -> Collected {file_collected} samples (scanned {lines_read} lines)")

                except Exception as e:
                    logger.error(f" ERROR reading {src_file.name}: {e}")

    except Exception as e:
        logger.critical(f"FATAL ERROR writing output file: {e}")
        sys.exit(1)

    logger.info("-" * 50)
    logger.info("SAMPLING COMPLETE")
    logger.info(f"Total entries collected: {total_collected}")
    logger.info(f"Output saved to: {output_file}")


if __name__ == "__main__":
    collect_samples()
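For reviewers, a minimal sketch of how the three filters combine on a single entry. The sample entry and the "de-" filename prefix are hypothetical; only the lang_code and pos fields are ones the script actually reads, and the predicate mirrors the checks in collect_samples().

import json

# Hypothetical entry in the shape the script reads (only lang_code and pos matter here)
entry = json.loads('{"word": "laufen", "lang_code": "de", "pos": "verb"}')

# Mirror the defaults above: no language restriction, verbs only, own-language check on
LANG_FILTER = set()
POS_FILTER = {"verb"}
OWN_LANG_FILTER = True

# Prefix taken from a hypothetical source file name, as the script does with name[:2]
lang_from_file = "de-raw-wiktextract-data.jsonl"[:2]

keep = (
    (not LANG_FILTER or entry.get("lang_code") in LANG_FILTER)
    and (not POS_FILTER or entry.get("pos") in POS_FILTER)
    and (not OWN_LANG_FILTER or entry.get("lang_code") == lang_from_file)
)
print(keep)  # True: it is a verb and its lang_code matches the "de" file prefix

Assuming the repository keeps raw_data/ next to scripts/, running python scripts/collect_samples.py from any working directory should write samples/combined_samples.jsonl.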