# Spaces status (as shown by the file viewer): Sleeping
# File size: 4,940 bytes
# Commit/revision id shown by the viewer: 9826f0b
# NOTE: the viewer's line-number gutter (1-163) was extraction residue and is omitted.
import json
import sqlite3
from pathlib import Path
import pytest
import retrieval_utils
from retrieval_utils import get_recommendations
# Setup Test DB
def _setup_test_db(db_path: "str | Path") -> None:
    """Create and populate a small SQLite database used by the tests.

    Three tables are created at ``db_path`` -- ``Anime``, ``Genre`` and
    ``AnimeGenre`` -- and filled with a fixed data set of four anime,
    three genres and the links between them.

    Args:
        db_path: Location of the SQLite database file. Either a string or a
            path-like object is accepted (``sqlite3.connect`` supports both).

    Raises:
        sqlite3.Error: If any statement fails, e.g. because the tables
            already exist in the target database. The connection is closed
            before the error propagates.
    """
    # Schema: one CREATE TABLE statement per test table.
    schema_statements = (
        """
        CREATE TABLE Anime (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            score REAL NOT NULL,
            synopsis TEXT
        );
        """,
        """
        CREATE TABLE Genre (
            id INTEGER PRIMARY KEY,
            genre_name TEXT NOT NULL
        );
        """,
        """
        CREATE TABLE AnimeGenre (
            anime_id INTEGER NOT NULL,
            genre_id INTEGER NOT NULL
        );
        """,
    )

    # Rows for the Anime table: (id, name, score, synopsis).
    anime_rows = [
        (1, "Alpha", 9.1, "Alpha synopsis"),
        (2, "Beta", 8.7, "Beta synopsis"),
        (3, "Gamma", 8.9, "Gamma synopsis"),
        (4, "Delta", 7.5, "Delta synopsis"),
    ]

    # Rows for the Genre table: (id, genre_name).
    genre_rows = [
        (1, "Action"),
        (2, "Drama"),
        (3, "Comedy"),
    ]

    # Rows for the AnimeGenre link table: (anime_id, genre_id).
    anime_genre_rows = [
        (1, 1), (1, 2),  # Alpha: Action, Drama (2 matches)
        (2, 1),          # Beta: Action (1 match)
        (3, 2),          # Gamma: Drama (1 match)
        (4, 3),          # Delta: Comedy (0 matches for Action/Drama)
    ]

    connection = sqlite3.connect(db_path)
    # try/finally guarantees the cursor and connection are released even when
    # a statement raises; otherwise the open handle on the file would leak.
    try:
        cursor = connection.cursor()
        try:
            for statement in schema_statements:
                cursor.execute(statement)

            cursor.executemany(
                "INSERT INTO Anime VALUES (?, ?, ?, ?);", anime_rows
            )
            cursor.executemany(
                "INSERT INTO Genre VALUES (?, ?);", genre_rows
            )
            cursor.executemany(
                "INSERT INTO AnimeGenre VALUES (?, ?);", anime_genre_rows
            )

            # Persist all writes to the database file.
            connection.commit()
        finally:
            cursor.close()
    finally:
        connection.close()
def test_get_recommendations_orders_by_match_count_then_score(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Check ordering of recommendations for the genres Action and Drama.

    With the fixture data, Alpha matches both requested genres while Gamma
    and Beta match one each; Gamma (8.9) has a higher score than Beta (8.7),
    so the expected order is Alpha, Gamma, Beta.
    """
    # Arrange: build the temporary database and point retrieval_utils at it.
    database_file = str(tmp_path / "test.db")
    _setup_test_db(database_file)
    monkeypatch.setattr(retrieval_utils, "DB_PATH", database_file)

    # Act: request up to three recommendations and decode the JSON payload.
    recommendations = json.loads(
        get_recommendations(["Action", "Drama"], limit=3)
    )

    # Assert: order of names, plus the shape of the first entry.
    returned_names = [entry["name"] for entry in recommendations]
    assert returned_names == ["Alpha", "Gamma", "Beta"]
    first_entry = recommendations[0]
    assert first_entry["score"] == 9.1
    assert "description" in first_entry
def test_get_recommendations_respects_limit(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Check that ``limit=1`` yields exactly one recommendation: Alpha."""
    # Arrange: build the temporary database and point retrieval_utils at it.
    database_file = str(tmp_path / "test.db")
    _setup_test_db(database_file)
    monkeypatch.setattr(retrieval_utils, "DB_PATH", database_file)

    # Act: ask for a single recommendation and decode the JSON payload.
    recommendations = json.loads(
        get_recommendations(["Action", "Drama"], limit=1)
    )

    # Assert: only one entry comes back, and it is the best match.
    assert len(recommendations) == 1
    only_entry = recommendations[0]
    assert only_entry["name"] == "Alpha"
def test_get_recommendations_single_genre(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Check recommendations when a single genre (Drama) is requested.

    In the fixture data only Alpha and Gamma are tagged Drama, and Alpha
    (9.1) has a higher score than Gamma (8.9), so Alpha is expected first.
    """
    # Setup Test Data and Mocks
    # Construct a temporary path for the Test DB
    db_path = tmp_path / "test.db"
    # Setup the test db (passed as str, matching the helper's signature and
    # the other tests in this module)
    _setup_test_db(str(db_path))
    # Monkeypatch the DB_PATH variable of retrieval_utils file
    monkeypatch.setattr(retrieval_utils, "DB_PATH", str(db_path))

    # Execute the Method under Test
    result_json = get_recommendations(["Drama"], limit=5)
    result = json.loads(result_json)

    # Assert on the results
    assert [item["name"] for item in result] == ["Alpha", "Gamma"]
    assert all("description" in item for item in result)
def test_get_recommendations_no_genre(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Check that an empty genre list produces an empty result."""
    # Setup Test Data and Mocks
    # Construct a temporary path for the Test DB
    db_path = tmp_path / "test.db"
    # Setup the test db (passed as str, matching the helper's signature and
    # the other tests in this module)
    _setup_test_db(str(db_path))
    # Monkeypatch the DB_PATH variable of retrieval_utils file
    monkeypatch.setattr(retrieval_utils, "DB_PATH", str(db_path))

    # Execute the Method under Test
    result_json = get_recommendations([], limit=5)
    result = json.loads(result_json)

    # Assert on the results
    assert len(result) == 0