diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md new file mode 100644 index 00000000..70fca447 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -0,0 +1,23 @@ +## Title + + +## Background + + +## Existing Behavior + + +## Acceptance Criteria +- [] + +## Approach + + +## References + + +## Risks and Rollback + + +## Screenshots / Recordings + \ No newline at end of file diff --git a/.gitignore b/.gitignore index d2cdbd62..984178dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ config/env/* !config/env/*.example -.idea/ \ No newline at end of file +.idea/ +.DS_Store \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 712082e7..b9f417e7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -147,6 +147,16 @@ Each module contains: - Auth endpoints via Djoser: `/auth/` - JWT token lifetime: 60 minutes (access), 1 day (refresh) +#### API Documentation +- Auto-generated using **drf-spectacular** (OpenAPI 3.0) +- **Swagger UI**: `http://localhost:8000/api/docs/` — interactive API explorer +- **ReDoc**: `http://localhost:8000/api/redoc/` — readable reference docs +- **Raw schema**: `http://localhost:8000/api/schema/` +- Configuration in `SPECTACULAR_SETTINGS` in `settings.py` +- Views use `@extend_schema` decorators and `serializer_class` attributes for schema generation +- JWT auth is configured in the schema — use `JWT ` (not `Bearer`) in Swagger UI's Authorize dialog +- To document a new endpoint: add `serializer_class` to the view if it has one, or add `@extend_schema` with `inline_serializer` for views returning raw dicts + #### Key Data Models - **Medication** (`api.views.listMeds.models`) - Medication catalog with benefits/risks - **MedRule** (`api.models.model_medRule`) - Include/Exclude rules for medications based on patient history diff --git a/README.md b/README.md index e5a246b1..fe765910 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,9 @@ for patients with bipolar disorder, helping them shorten their journey to stabil ## Usage -You can view the current build of the website here: [https://balancertestsite.com](https://balancertestsite.com/) +You can view the current build of the website here: [https://balancerproject.org/](https://balancerproject.org/) + +You can view the website in a sandbox here: [https://sandbox.balancerproject.org/](https://sandbox.balancerproject.org/) ## Contributing @@ -31,11 +33,9 @@ Get the code using git by either forking or cloning `CodeForPhilly/balancer-main ``` 2. (Optional) Add your API keys to `config/env/dev.env`: - `OpenAI API` - - `Anthropic API` Tools used for development: 1. `Docker`: Install Docker Desktop -2. `Postman`: Ask to get invited to the Balancer Postman team `balancer_dev` 3. `npm`: In the terminal run 1) 'cd frontend' 2) 'npm install' 3) 'cd ..' ### Running Balancer for development @@ -53,7 +53,7 @@ The application supports connecting to PostgreSQL databases via: See [Database Connection Documentation](./docs/DATABASE_CONNECTION.md) for detailed configuration. **Local Development:** -- Download a sample of papers to upload from [https://balancertestsite.com](https://balancertestsite.com/) +- Download a sample of papers to upload from [https://balancerproject.org/](https://balancerproject.org/) - The email and password of `pgAdmin` are specified in `balancer-main/docker-compose.yml` - The first time you use `pgAdmin` after building the Docker containers you will need to register the server. - The `Host name/address` is the Postgres server service name in the Docker Compose file @@ -73,6 +73,36 @@ df = pd.read_sql(query, engine) #### Django REST - The email and password are set in `server/api/management/commands/createsu.py` +- Backend tests can be run using `pytest` by running the below command inside the running backend container: + +``` +docker compose exec backend pytest api/ -v +``` + +## API Documentation + +Interactive API docs are auto-generated using [drf-spectacular](https://drf-spectacular.readthedocs.io/) and available at: + +- **Swagger UI**: [http://localhost:8000/api/docs/](http://localhost:8000/api/docs/) — interactive explorer with "Try it out" functionality +- **ReDoc**: [http://localhost:8000/api/redoc/](http://localhost:8000/api/redoc/) — clean, readable reference docs +- **Raw schema**: [http://localhost:8000/api/schema/](http://localhost:8000/api/schema/) — OpenAPI 3.0 JSON/YAML + +### Testing authenticated endpoints + +Most endpoints require JWT authentication. To test them in Swagger UI: + +1. **Get a token**: Find the `POST /auth/jwt/create/` endpoint in Swagger UI, click **Try it out**, enter an authorized `email` and `password`, and click **Execute**. Copy the `access` token from the response. +2. **Authorize**: Click the **Authorize** button (lock icon) at the top of the page. Enter `JWT ` in the value field. The prefix must be `JWT`, not `Bearer`. +3. **Test endpoints**: All subsequent requests will include your token. Use **Try it out** on any protected endpoint. +4. **Token refresh**: Access tokens expire after 60 minutes. Use `POST /auth/jwt/refresh/` with your `refresh` token, or repeat step 1. + +### Deployment + +1. Merging your PR into develop automatically triggers a GitHub Release +2. The release triggers a container build workflow that builds and pushes the Docker image +3. [Go to GitHub Packages](https://github.com/CodeForPhilly/balancer-main/pkgs/container/balancer-main%2Fapp) to find the new image tag +4. Update newTag in kustomization.yaml [in the cluster repo](https://github.com/CodeForPhilly/cfp-live-cluster/blob/main/balancer/kustomization.yaml) +5. Open a PR to [cfp-sandbox-cluster](https://github.com/CodeForPhilly/cfp-sandbox-cluster) (or [cfp-live-cluster](https://github.com/CodeForPhilly/cfp-live-cluster)) ## Architecture diff --git a/docker-compose.yml b/docker-compose.yml index 9182cdb6..7a6e7fe9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,11 +18,6 @@ services: networks: app_net: ipv4_address: 192.168.0.2 - healthcheck: - test: ["CMD-SHELL", "pg_isready -U balancer -d balancer_dev"] - interval: 5s - timeout: 5s - retries: 5 pgadmin: image: dpage/pgadmin4 diff --git a/frontend/src/components/Footer/Footer.tsx b/frontend/src/components/Footer/Footer.tsx index d656f5ad..977c59d4 100644 --- a/frontend/src/components/Footer/Footer.tsx +++ b/frontend/src/components/Footer/Footer.tsx @@ -62,11 +62,11 @@ function Footer() { > Leave feedback - - Donate + Support Development = ({ isAuthenticated, isSuperuser }) => { Leave Feedback - Donate + Support Development {isAuthenticated && isSuperuser && (
{
  • - - Donate + Support Development
  • {isAuthenticated && diff --git a/frontend/src/pages/About/About.tsx b/frontend/src/pages/About/About.tsx index c50f6705..9481c74d 100644 --- a/frontend/src/pages/About/About.tsx +++ b/frontend/src/pages/About/About.tsx @@ -77,9 +77,9 @@ function About() {
    - + diff --git a/frontend/src/pages/DocumentManager/UploadFile.tsx b/frontend/src/pages/DocumentManager/UploadFile.tsx index 2ee7b5db..32b727e8 100644 --- a/frontend/src/pages/DocumentManager/UploadFile.tsx +++ b/frontend/src/pages/DocumentManager/UploadFile.tsx @@ -1,5 +1,5 @@ import React, { useState, useRef } from "react"; -import axios from "axios"; +import { adminApi } from "../../api/apiClient"; import TypingAnimation from "../../components/Header/components/TypingAnimation.tsx"; import Layout from "../Layout/Layout.tsx"; @@ -22,14 +22,9 @@ const UploadFile: React.FC = () => { formData.append("file", file); try { - const response = await axios.post( + const response = await adminApi.post( `/api/v1/api/uploadFile`, formData, - { - headers: { - "Content-Type": "multipart/form-data" - }, - } ); console.log("File uploaded successfully", response.data); } catch (error) { diff --git a/frontend/src/pages/Files/ListOfFiles.tsx b/frontend/src/pages/Files/ListOfFiles.tsx index b6fff4ee..37bd459a 100644 --- a/frontend/src/pages/Files/ListOfFiles.tsx +++ b/frontend/src/pages/Files/ListOfFiles.tsx @@ -61,7 +61,7 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ const handleDownload = async (guid: string, fileName: string) => { try { setDownloading(guid); - const { data } = await publicApi.get(`/v1/api/uploadFile/${guid}`, { responseType: 'blob' }); + const { data } = await publicApi.get(`/api/v1/api/uploadFile/${guid}`, { responseType: 'blob' }); const url = window.URL.createObjectURL(new Blob([data])); const link = document.createElement("a"); @@ -82,7 +82,7 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ const handleOpen = async (guid: string) => { try { setOpening(guid); - const { data } = await publicApi.get(`/v1/api/uploadFile/${guid}`, { responseType: 'arraybuffer' }); + const { data } = await publicApi.get(`/api/v1/api/uploadFile/${guid}`, { responseType: 'arraybuffer' }); const file = new Blob([data], { type: 'application/pdf' }); const fileURL = window.URL.createObjectURL(file); diff --git a/server/api/apps.py b/server/api/apps.py index 66656fd2..13977850 100644 --- a/server/api/apps.py +++ b/server/api/apps.py @@ -4,3 +4,38 @@ class ApiConfig(AppConfig): default_auto_field = 'django.db.models.BigAutoField' name = 'api' + + def ready(self): + + try: + import os + import sys + + # ready() runs in every Django process: migrate, test, shell, runserver, etc. + # Only preload the model when we're actually going to serve requests. + # Dev (docker-compose.yml) runs `manage.py runserver 0.0.0.0:8000`. + # Prod (Dockerfile.prod CMD) runs `manage.py runserver 0.0.0.0:8000 --noreload`. + # entrypoint.prod.sh also runs migrate, createsu, and populatedb before exec'ing + # runserver — the guard below correctly skips model loading for those commands too. + if sys.argv[1:2] != ['runserver']: + return + + # Dev's autoreloader spawns two processes: a parent file-watcher and a child + # server. ready() runs in both, but only the child (RUN_MAIN=true) serves + # requests. Skip the parent to avoid loading the model twice on each file change. + # Prod uses --noreload so RUN_MAIN is never set; 'noreload' in sys.argv handles that case. + if os.environ.get('RUN_MAIN') != 'true' and '--noreload' not in sys.argv: + return + + # Note: paraphrase-MiniLM-L6-v2 (~80MB) is downloaded from HuggingFace on first + # use and cached to ~/.cache/torch/sentence_transformers/ inside the container. + # That cache is ephemeral — every container rebuild re-downloads the model unless + # a volume is mounted at that path. + from .services.sentencetTransformer_model import TransformerModel + TransformerModel.get_instance() + except Exception: + # TransformerModel._instance stays None on failure, so the first actual request + # that calls get_instance() will attempt to load the model again. + import logging + logger = logging.getLogger(__name__) + logger.exception("Failed to preload the embedding model at startup") diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py index e35f7965..213519e5 100644 --- a/server/api/services/embedding_services.py +++ b/server/api/services/embedding_services.py @@ -2,6 +2,7 @@ import logging from statistics import median +# Use Q objects to express OR conditions in Django queries from django.db.models import Q from pgvector.django import L2Distance @@ -11,18 +12,17 @@ logger = logging.getLogger(__name__) -def get_closest_embeddings( - user, message_data, document_name=None, guid=None, num_results=10 -): + +def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10): """ - Find the closest embeddings to a given message for a specific user. + Build an unevaluated QuerySet for the closest embeddings. Parameters ---------- user : User The user whose uploaded documents will be searched - message_data : str - The input message to find similar embeddings for + embedding_vector : array-like + Pre-computed embedding vector to compare against document_name : str, optional Filter results to a specific document name guid : str, optional @@ -32,59 +32,52 @@ def get_closest_embeddings( Returns ------- - list[dict] - List of dictionaries containing embedding results with keys: - - name: document name - - text: embedded text content - - page_number: page number in source document - - chunk_number: chunk number within the document - - distance: L2 distance from query embedding - - file_id: GUID of the source file + QuerySet + Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results """ - - encoding_start = time.time() - transformerModel = TransformerModel.get_instance().model - embedding_message = transformerModel.encode(message_data) - encoding_time = time.time() - encoding_start - - db_query_start = time.time() - # Django QuerySets are lazily evaluated if user.is_authenticated: # User sees their own files + files uploaded by superusers - closest_embeddings_query = ( - Embeddings.objects.filter( - Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) - ) - .annotate( - distance=L2Distance("embedding_sentence_transformers", embedding_message) - ) - .order_by("distance") + queryset = Embeddings.objects.filter( + Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) ) else: # Unauthenticated users only see superuser-uploaded files - closest_embeddings_query = ( - Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) - .annotate( - distance=L2Distance("embedding_sentence_transformers", embedding_message) - ) - .order_by("distance") - ) + queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) + + queryset = ( + queryset + .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector)) + .order_by("distance") + ) # Filtering to a document GUID takes precedence over a document name if guid: - closest_embeddings_query = closest_embeddings_query.filter( - upload_file__guid=guid - ) + queryset = queryset.filter(upload_file__guid=guid) elif document_name: - closest_embeddings_query = closest_embeddings_query.filter(name=document_name) + queryset = queryset.filter(name=document_name) # Slicing is equivalent to SQL's LIMIT clause - closest_embeddings_query = closest_embeddings_query[:num_results] + return queryset[:num_results] + + +def evaluate_query(queryset): + """ + Evaluate a QuerySet and return a list of result dicts. + + Parameters + ---------- + queryset : iterable + Iterable of Embeddings objects (or any objects with the expected attributes) + Returns + ------- + list[dict] + List of dicts with keys: name, text, page_number, chunk_number, distance, file_id + """ # Iterating evaluates the QuerySet and hits the database # TODO: Research improving the query evaluation performance - results = [ + return [ { "name": obj.name, "text": obj.text, @@ -93,13 +86,36 @@ def get_closest_embeddings( "distance": obj.distance, "file_id": obj.upload_file.guid if obj.upload_file else None, } - for obj in closest_embeddings_query + for obj in queryset ] - db_query_time = time.time() - db_query_start +def log_usage( + results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time +): + """ + Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. + + Parameters + ---------- + results : list[dict] + The search results, each containing a "distance" key + message_data : str + The original search query text + user : User + The user who performed the search + guid : str or None + Document GUID filter used in the search + document_name : str or None + Document name filter used in the search + num_results : int + Number of results requested + encoding_time : float + Time in seconds to encode the query + db_query_time : float + Time in seconds for the database query + """ try: - # Handle user having no uploaded docs or doc filtering returning no matches if results: distances = [r["distance"] for r in results] SemanticSearchUsage.objects.create( @@ -113,11 +129,10 @@ def get_closest_embeddings( num_results_returned=len(results), max_distance=max(distances), median_distance=median(distances), - min_distance=min(distances) + min_distance=min(distances), ) else: logger.warning("Semantic search returned no results") - SemanticSearchUsage.objects.create( query_text=message_data, user=user if (user and user.is_authenticated) else None, @@ -129,9 +144,58 @@ def get_closest_embeddings( num_results_returned=0, max_distance=None, median_distance=None, - min_distance=None + min_distance=None, ) - except Exception as e: - logger.error(f"Failed to create semantic search usage database record: {e}") + except Exception: + logger.exception("Failed to create semantic search usage database record") + + +def get_closest_embeddings( + user, message_data, document_name=None, guid=None, num_results=10 +): + """ + Find the closest embeddings to a given message for a specific user. + + Parameters + ---------- + user : User + The user whose uploaded documents will be searched + message_data : str + The input message to find similar embeddings for + document_name : str, optional + Filter results to a specific document name + guid : str, optional + Filter results to a specific document GUID (takes precedence over document_name) + num_results : int, default 10 + Maximum number of results to return + + Returns + ------- + list[dict] + List of dictionaries containing embedding results with keys: + - name: document name + - text: embedded text content + - page_number: page number in source document + - chunk_number: chunk number within the document + - distance: L2 distance from query embedding + - file_id: GUID of the source file + + Notes + ----- + Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. + """ + encoding_start = time.time() + model = TransformerModel.get_instance().model + embedding_vector = model.encode(message_data) + encoding_time = time.time() - encoding_start + + db_query_start = time.time() + queryset = build_query(user, embedding_vector, document_name, guid, num_results) + results = evaluate_query(queryset) + db_query_time = time.time() - db_query_start + + log_usage( + results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time + ) return results diff --git a/server/api/services/test_embedding_services.py b/server/api/services/test_embedding_services.py new file mode 100644 index 00000000..e43c0d74 --- /dev/null +++ b/server/api/services/test_embedding_services.py @@ -0,0 +1,400 @@ +from unittest.mock import MagicMock, patch + +from django.db.models import Q +from pgvector.django import L2Distance + +from api.services.embedding_services import ( + build_query, + evaluate_query, + get_closest_embeddings, + log_usage, +) + +# Each function is tested one responsibility at a time. One test for the whole +# function collapses all responsibilities into a single assertion block — when +# it fails you know something is broken but not which responsibility. You have +# to debug to find out. + +# --------------------------------------------------------------------------- +# build_query tests +# +# build_query is responsible for access control, annotate/order, document filter +# and slicing and only constructs a lazy Django QuerySet without evaluating it +# +# We can test build_query by patching Embeddings.objects and inspecting which +# methods and arguments were called on Embeddings.objects +# --------------------------------------------------------------------------- + +# Only forwarded to L2Distance +EMBEDDING_VECTOR = [0.1, 0.2, 0.3] + +# Test authenticated/unauthenticated user access control + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_authenticated_uses_or_filter(mock_objects): + # An authenticated user should see their own files OR files uploaded by a + # superuser. The initial filter must use an OR-connected Q expression. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR) + + # Q objects support equality comparison in pure Python — no DB needed. + expected_q = Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) + actual_q = mock_objects.filter.call_args.args[0] + assert actual_q == expected_q + + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_unauthenticated_uses_superuser_only_filter(mock_objects): + # An unauthenticated user may only see files uploaded by superusers. + # The source uses a plain kwarg here (not a positional Q object), so the + # value lives in call_args.kwargs, not call_args.args. + user = MagicMock(is_authenticated=False) + + build_query(user, EMBEDDING_VECTOR) + + assert mock_objects.filter.call_args.kwargs == {"upload_file__uploaded_by__is_superuser": True} + +# Test application of annotate and order_by + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_annotates_and_orders_by_distance(mock_objects): + # Regardless of other arguments, annotate(distance=L2Distance(...)) and + # order_by("distance") must always be applied to the queryset. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR) + + # Retrieve the mock chain that .filter() returned, then check its methods. + filtered_qs = mock_objects.filter.return_value + filtered_qs.annotate.assert_called_once() + filtered_qs.annotate.return_value.order_by.assert_called_once_with("distance") + + # L2Distance is a Django Func subclass, which implements __eq__ by comparing + # class and source expressions — so we can assert the exact field name and + # vector without patching L2Distance itself. + actual_distance_expr = filtered_qs.annotate.call_args.kwargs["distance"] + assert actual_distance_expr == L2Distance("embedding_sentence_transformers", EMBEDDING_VECTOR) + +# Test guid-over-document precedence logic + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_no_document_filter_when_both_none(mock_objects): + # When neither guid nor document_name is provided, only the access-control + # filter should fire — no secondary filter call for a document. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, document_name=None, guid=None) + + # Exactly one filter call: the auth/access-control filter. + assert mock_objects.filter.call_count == 1 + + + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_guid_takes_precedence_over_document_name(mock_objects): + # When both guid and document_name are provided, the guid branch runs and + # the document_name branch is skipped entirely. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, guid="abc-123", document_name="study.pdf") + + # The auth filter fires on mock_objects.filter (call_count == 1). + # The document filter fires on the chained ordered_qs.filter — a different + # mock object — so mock_objects.filter.call_count stays at 1. + assert mock_objects.filter.call_count == 1 + + # The document filter must use upload_file__guid, not name, and must be + # called exactly once (confirming document_name branch was skipped). + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + ordered_qs.filter.assert_called_once_with(upload_file__guid="abc-123") + + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_guid_filter_applied(mock_objects): + # When only guid is given, a second filter on upload_file__guid is applied. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, guid="doc-guid-456") + + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + ordered_qs.filter.assert_called_once_with(upload_file__guid="doc-guid-456") + + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_document_name_filter_applied(mock_objects): + # When only document_name is given (guid is None), a second filter on + # name is applied instead of upload_file__guid. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, document_name="study.pdf", guid=None) + + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + ordered_qs.filter.assert_called_once_with(name="study.pdf") + + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_empty_string_guid_falls_back_to_document_name(mock_objects): + # An empty-string guid is falsy in Python, so it should not trigger the + # guid branch. The document_name filter should fire instead. This guards + # against callers passing guid="" from an unset form field. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, guid="", document_name="fallback.pdf") + + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + ordered_qs.filter.assert_called_once_with(name="fallback.pdf") + +# Cover LIMIT slicing + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_respects_num_results(mock_objects): + # num_results controls the SQL LIMIT via queryset slicing. Verify that a + # non-default value propagates correctly to the __getitem__ call. + user = MagicMock(is_authenticated=True) + + build_query(user, EMBEDDING_VECTOR, num_results=5) + + # Django translates qs[:5] into qs.__getitem__(slice(None, 5, None)). + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + ordered_qs.__getitem__.assert_called_once_with(slice(None, 5, None)) + +@patch("api.services.embedding_services.Embeddings.objects") +def test_build_query_returns_unevaluated_queryset(mock_objects): + # build_query must NOT evaluate the queryset (no list(), no iteration). + # The return value should be the mock produced by the final __getitem__ call. + user = MagicMock(is_authenticated=True) + + result = build_query(user, EMBEDDING_VECTOR) + + ordered_qs = mock_objects.filter.return_value.annotate.return_value.order_by.return_value + assert result is ordered_qs.__getitem__.return_value + assert not isinstance(result, list) + + +# --------------------------------------------------------------------------- +# evaluate_query tests +# +# evaluate_query is responsible for iterating the queryset and mapping each +# Embeddings object's attributes to a result dict, including the rename +# page_num -> page_number and the None-safe file_id lookup +# +# We can test evaluate_query by passing plain MagicMock objects directly as +# the iterable and asserting on the shape and values of the returned list +# --------------------------------------------------------------------------- + +def test_evaluate_query_empty_queryset(): + # An empty iterable should return an empty list, not raise an exception. + assert evaluate_query([]) == [] + + +def test_evaluate_query_maps_fields(): + # Verify that each Embeddings model attribute is mapped to the correct + # output dict key. Note the rename: obj.page_num -> result["page_number"]. + obj = MagicMock() + obj.name = "doc.pdf" + obj.text = "some text" + obj.page_num = 3 + obj.chunk_number = 1 + obj.distance = 0.42 + obj.upload_file.guid = "abc-123" + + results = evaluate_query([obj]) + + assert results == [ + { + "name": "doc.pdf", + "text": "some text", + "page_number": 3, + "chunk_number": 1, + "distance": 0.42, + "file_id": "abc-123", + } + ] + + +def test_evaluate_query_none_upload_file(): + # When upload_file is None, file_id must be None rather than raising + # an AttributeError on None.guid. + obj = MagicMock() + obj.name = "doc.pdf" + obj.text = "some text" + obj.page_num = 1 + obj.chunk_number = 0 + obj.distance = 1.0 + obj.upload_file = None + + results = evaluate_query([obj]) + + assert results[0]["file_id"] is None + +# --------------------------------------------------------------------------- +# log_usage tests +# +# log_usage is responsible for computing distance stats, storing the correct +# user (None for unauthenticated), handling empty results, and swallowing +# exceptions so search is never interrupted +# +# We can test log_usage by patching SemanticSearchUsage.objects.create and +# inspecting the keyword arguments it was called with +# --------------------------------------------------------------------------- + +@patch("api.services.embedding_services.SemanticSearchUsage.objects.create") +def test_log_usage_empty_results(mock_create): + # Empty results hits the else branch. The record should still be created + # with num_results_returned=0 and all distance fields set to None. + user = MagicMock(is_authenticated=True) + + log_usage( + [], + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + mock_create.assert_called_once() + kwargs = mock_create.call_args.kwargs + assert kwargs["num_results_returned"] == 0 + assert kwargs["max_distance"] is None + assert kwargs["median_distance"] is None + assert kwargs["min_distance"] is None + + +@patch("api.services.embedding_services.SemanticSearchUsage.objects.create") +def test_log_usage_unauthenticated_user_stored_as_none(mock_create): + # An unauthenticated user should be stored as None in the DB record, not as + # the user object itself, so the FK constraint is not violated. + user = MagicMock(is_authenticated=False) + + log_usage( + [{"distance": 1.0}], + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + kwargs = mock_create.call_args.kwargs + assert kwargs["user"] is None + + +@patch("api.services.embedding_services.SemanticSearchUsage.objects.create") +def test_log_usage_none_user_stored_as_none(mock_create): + # Passing user=None directly (e.g. from an anonymous request) should also + # store None — the expression `user if (user and user.is_authenticated)` + # short-circuits on the falsy None before accessing .is_authenticated. + log_usage( + [{"distance": 1.0}], + message_data="test query", + user=None, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + kwargs = mock_create.call_args.kwargs + assert kwargs["user"] is None + + +@patch("api.services.embedding_services.SemanticSearchUsage.objects.create") +def test_log_usage_computes_distance_stats(mock_create): + # Verify min, max, and median are computed correctly from the distance + # values in the results list and forwarded to the DB record. + results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}] + user = MagicMock(is_authenticated=True) + + log_usage( + results, + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + mock_create.assert_called_once() + kwargs = mock_create.call_args.kwargs + assert kwargs["min_distance"] == 1.0 + assert kwargs["max_distance"] == 3.0 + assert kwargs["median_distance"] == 2.0 + assert kwargs["num_results_returned"] == 3 + + +@patch( + "api.services.embedding_services.SemanticSearchUsage.objects.create", + side_effect=Exception("DB error"), +) +def test_log_usage_swallows_exceptions(mock_create): + # log_usage must not propagate exceptions — a logging failure should never + # interrupt the caller's search flow. + # pytest fails the test if it catches unhandled Exception + results = [{"distance": 1.0}] + user = MagicMock(is_authenticated=True) + + log_usage( + results, + message_data="test query", + user=user, + guid=None, + document_name=None, + num_results=10, + encoding_time=0.1, + db_query_time=0.2, + ) + + +# --------------------------------------------------------------------------- +# get_closest_embeddings tests +# +# get_closest_embeddings is responsible for wiring together encode, +# build_query, evaluate_query, and log_usage and returning the results +# +# We can test get_closest_embeddings by patching all four collaborators and +# asserting that each is called with the correct arguments in the correct order +# --------------------------------------------------------------------------- + +@patch("api.services.embedding_services.log_usage") +@patch("api.services.embedding_services.evaluate_query") +@patch("api.services.embedding_services.build_query") +@patch("api.services.embedding_services.TransformerModel") +def test_get_closest_embeddings_wiring(mock_transformer, mock_build, mock_evaluate, mock_log): + # Smoke test verifying that get_closest_embeddings correctly wires together + # encode → build_query → evaluate_query → log_usage and returns the results. + user = MagicMock(is_authenticated=True) + + # Simulate the model encoding the message to a vector. + fake_vector = [0.1, 0.2, 0.3] + mock_transformer.get_instance.return_value.model.encode.return_value = fake_vector + + # build_query returns a queryset; evaluate_query turns it into a results list. + fake_queryset = MagicMock() + mock_build.return_value = fake_queryset + fake_results = [{"name": "doc.pdf", "distance": 0.5}] + mock_evaluate.return_value = fake_results + + result = get_closest_embeddings(user, "some query", document_name="doc.pdf", guid=None, num_results=5) + + # The encoded vector must be forwarded to build_query. + mock_build.assert_called_once_with(user, fake_vector, "doc.pdf", None, 5) + + # evaluate_query must receive the queryset that build_query returned. + mock_evaluate.assert_called_once_with(fake_queryset) + + # log_usage must be called with the results and original parameters. + mock_log.assert_called_once() + log_kwargs = mock_log.call_args.args + assert log_kwargs[0] is fake_results + + # The function must return evaluate_query's result unchanged. + assert result is fake_results diff --git a/server/api/views/ai_promptStorage/views.py b/server/api/views/ai_promptStorage/views.py index 7354feb3..cc50f22e 100644 --- a/server/api/views/ai_promptStorage/views.py +++ b/server/api/views/ai_promptStorage/views.py @@ -1,10 +1,12 @@ from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response +from drf_spectacular.utils import extend_schema from .models import AI_PromptStorage from .serializers import AI_PromptStorageSerializer +@extend_schema(request=AI_PromptStorageSerializer, responses={201: AI_PromptStorageSerializer}) @api_view(['POST']) # @permission_classes([IsAuthenticated]) def store_prompt(request): @@ -21,6 +23,7 @@ def store_prompt(request): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +@extend_schema(responses={200: AI_PromptStorageSerializer(many=True)}) @api_view(['GET']) def get_all_prompts(request): """ diff --git a/server/api/views/ai_settings/views.py b/server/api/views/ai_settings/views.py index 349b9fd9..9ee6aad7 100644 --- a/server/api/views/ai_settings/views.py +++ b/server/api/views/ai_settings/views.py @@ -2,10 +2,12 @@ from rest_framework.decorators import api_view, permission_classes from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response +from drf_spectacular.utils import extend_schema from .models import AI_Settings from .serializers import AISettingsSerializer +@extend_schema(request=AISettingsSerializer, responses={200: AISettingsSerializer(many=True), 201: AISettingsSerializer}) @api_view(['GET', 'POST']) @permission_classes([IsAuthenticated]) def settings_view(request): diff --git a/server/api/views/assistant/sanitizer.py b/server/api/views/assistant/sanitizer.py index bdbbc77f..fd851df6 100644 --- a/server/api/views/assistant/sanitizer.py +++ b/server/api/views/assistant/sanitizer.py @@ -1,26 +1,76 @@ import re import logging + logger = logging.getLogger(__name__) def sanitize_input(user_input:str) -> str: """ Sanitize user input to prevent injection attacks and remove unwanted characters. + Args: user_input (str): The raw input string from the user. + Returns: str: The sanitized input string. """ try: - # Remove any script tags - sanitized = re.sub(r'.*?', '', user_input, flags=re.IGNORECASE) - # Remove any HTML tags + sanitized = user_input + + # Remove any style tags + sanitized = re.sub(r'.*?', '', sanitized, flags=re.IGNORECASE) + + # Remove any HTML/script tags sanitized = re.sub(r'<.*?>', '', sanitized) + + # Remove Phone Numbers + sanitized = re.sub(r'\+?\d[\d -]{8,}\d', '[Phone Number]', sanitized) + + # Remove Email Addresses + sanitized = re.sub(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', '[Email Address]', sanitized) + + # Remove Medical Record Numbers (simple pattern) + sanitized = re.sub(r'\bMRN[:\s]*\d+\b', '[Medical Record Number]', sanitized, flags=re.IGNORECASE) + + # Normalize pronouns + sanitized = normalize_pronouns(sanitized) + # Escape special characters - sanitized = re.sub(r'["\'\\]', '', sanitized) + sanitized = re.sub(r'\s+', '', sanitized) + # Limit length to prevent buffer overflow attacks - max_length = 1000 + max_length = 5000 if len(sanitized) > max_length: sanitized = sanitized[:max_length] + return sanitized.strip() except Exception as e: logger.error(f"Error sanitizing input: {e}") - return "" \ No newline at end of file + return "" + +def normalize_pronouns(text:str) -> str: + """ + Normalize first and second person pronouns to third person clinical language. + + Converts patient centric pronouns to a more neutral form. + Args: + text (str): The input text containing pronouns. + Returns: + str: The text with normalized pronouns. + """ + # Normalize first person possessives: I, me, my, mine -> the patient + text = re.sub(r'\bMy\b', 'The patient\'s', text) + text = re.sub(r'\bmy\b', 'the patient\'s', text) + + # First person subject: I -> the patient + text = re.sub(r'\bI\b', 'the patient', text) + + # First person object: me -> the patient + text = re.sub(r'\bme\b', 'the patient', text) + + # First person reflexive: myself -> the patient + text = re.sub(r'\bmyself\b', 'the patient', text) + + # Second person: you, your -> the clinician + text = re.sub(r'\bYour\b', 'the clinician', text) + return text + + diff --git a/server/api/views/assistant/views.py b/server/api/views/assistant/views.py index f31ab475..e3e8d6f7 100644 --- a/server/api/views/assistant/views.py +++ b/server/api/views/assistant/views.py @@ -10,6 +10,8 @@ from rest_framework.permissions import AllowAny from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from drf_spectacular.utils import extend_schema, inline_serializer +from rest_framework import serializers as drf_serializers from openai import OpenAI @@ -113,6 +115,21 @@ def invoke_functions_from_response( class Assistant(APIView): permission_classes = [AllowAny] + @extend_schema( + request=inline_serializer(name='AssistantRequest', fields={ + 'message': drf_serializers.CharField(help_text='User message to send to the assistant'), + 'previous_response_id': drf_serializers.CharField(required=False, allow_null=True, help_text='ID of previous response for conversation continuity'), + }), + responses={ + 200: inline_serializer(name='AssistantResponse', fields={ + 'response_output_text': drf_serializers.CharField(), + 'final_response_id': drf_serializers.CharField(), + }), + 500: inline_serializer(name='AssistantError', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def post(self, request): try: user = request.user diff --git a/server/api/views/conversations/views.py b/server/api/views/conversations/views.py index eeb68809..de927cf1 100644 --- a/server/api/views/conversations/views.py +++ b/server/api/views/conversations/views.py @@ -16,6 +16,8 @@ from .models import Conversation, Message from .serializers import ConversationSerializer from ...services.tools.tools import tools, execute_tool +from drf_spectacular.utils import extend_schema, inline_serializer +from rest_framework import serializers as drf_serializers @csrf_exempt @@ -95,6 +97,21 @@ def destroy(self, request, *args, **kwargs): self.perform_destroy(instance) return Response(status=status.HTTP_204_NO_CONTENT) + @extend_schema( + request=inline_serializer(name='ContinueConversationRequest', fields={ + 'message': drf_serializers.CharField(help_text='User message to continue the conversation'), + 'page_context': drf_serializers.CharField(required=False, help_text='Optional page context'), + }), + responses={ + 200: inline_serializer(name='ContinueConversationResponse', fields={ + 'response': drf_serializers.CharField(), + 'title': drf_serializers.CharField(), + }), + 400: inline_serializer(name='ContinueConversationBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) @action(detail=True, methods=['post']) def continue_conversation(self, request, pk=None): conversation = self.get_object() @@ -123,6 +140,20 @@ def continue_conversation(self, request, pk=None): return Response({"response": chatgpt_response, "title": conversation.title}) + @extend_schema( + request=inline_serializer(name='UpdateTitleRequest', fields={ + 'title': drf_serializers.CharField(help_text='New conversation title'), + }), + responses={ + 200: inline_serializer(name='UpdateTitleResponse', fields={ + 'status': drf_serializers.CharField(), + 'title': drf_serializers.CharField(), + }), + 400: inline_serializer(name='UpdateTitleBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) @action(detail=True, methods=['patch']) def update_title(self, request, pk=None): conversation = self.get_object() diff --git a/server/api/views/embeddings/embeddingsView.py b/server/api/views/embeddings/embeddingsView.py index d0bdd8ca..ebcf0774 100644 --- a/server/api/views/embeddings/embeddingsView.py +++ b/server/api/views/embeddings/embeddingsView.py @@ -1,8 +1,9 @@ from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from rest_framework import status +from rest_framework import status, serializers as drf_serializers from django.http import StreamingHttpResponse +from drf_spectacular.utils import extend_schema, inline_serializer, OpenApiParameter from ...services.embedding_services import get_closest_embeddings from ...services.conversions_services import convert_uuids from ...services.openai_services import openAIServices @@ -15,6 +16,26 @@ class AskEmbeddingsAPIView(APIView): permission_classes = [IsAuthenticated] + @extend_schema( + parameters=[ + OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=False, description='Optional file GUID to filter embeddings'), + OpenApiParameter(name='stream', type=bool, location=OpenApiParameter.QUERY, required=False, description='Enable streaming response'), + ], + request=inline_serializer(name='AskEmbeddingsRequest', fields={ + 'message': drf_serializers.CharField(help_text='Question to ask against embedded documents'), + }), + responses={ + 200: inline_serializer(name='AskEmbeddingsResponse', fields={ + 'question': drf_serializers.CharField(), + 'llm_response': drf_serializers.CharField(), + 'embeddings_info': drf_serializers.CharField(), + 'sent_to_llm': drf_serializers.CharField(), + }), + 400: inline_serializer(name='AskEmbeddingsBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def post(self, request, *args, **kwargs): try: user = request.user diff --git a/server/api/views/feedback/views.py b/server/api/views/feedback/views.py index d0f0e1da..424e0758 100644 --- a/server/api/views/feedback/views.py +++ b/server/api/views/feedback/views.py @@ -9,6 +9,7 @@ class FeedbackView(APIView): permission_classes = [AllowAny] + serializer_class = FeedbackSerializer def post(self, request, *args, **kwargs): serializer = FeedbackSerializer(data=request.data) diff --git a/server/api/views/listMeds/views.py b/server/api/views/listMeds/views.py index fcd0edf2..1b199a7e 100644 --- a/server/api/views/listMeds/views.py +++ b/server/api/views/listMeds/views.py @@ -1,7 +1,8 @@ -from rest_framework import status +from rest_framework import status, serializers as drf_serializers from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView +from drf_spectacular.utils import extend_schema, inline_serializer from .models import Diagnosis, Medication, Suggestion from .serializers import MedicationSerializer @@ -24,6 +25,33 @@ class GetMedication(APIView): permission_classes = [AllowAny] + @extend_schema( + request=inline_serializer( + name='GetMedicationRequest', + fields={ + 'state': drf_serializers.CharField(help_text='Diagnosis state, e.g. "depressed", "manic"'), + 'suicideHistory': drf_serializers.BooleanField(default=False), + 'kidneyHistory': drf_serializers.BooleanField(default=False), + 'liverHistory': drf_serializers.BooleanField(default=False), + 'bloodPressureHistory': drf_serializers.BooleanField(default=False), + 'weightGainConcern': drf_serializers.BooleanField(default=False), + 'priorMedications': drf_serializers.CharField(required=False, default='', help_text='Comma-separated medication names'), + } + ), + responses={ + 200: inline_serializer( + name='GetMedicationResponse', + fields={ + 'first': drf_serializers.ListField(child=drf_serializers.DictField()), + 'second': drf_serializers.ListField(child=drf_serializers.DictField()), + 'third': drf_serializers.ListField(child=drf_serializers.DictField()), + } + ), + 404: inline_serializer(name='GetMedicationNotFound', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def post(self, request): data = request.data state_query = data.get('state', '') @@ -75,6 +103,7 @@ def post(self, request): class ListOrDetailMedication(APIView): permission_classes = [AllowAny] + serializer_class = MedicationSerializer def get(self, request): name_query = request.query_params.get('name', None) @@ -98,6 +127,7 @@ class AddMedication(APIView): """ API endpoint to add a medication to the database with its risks and benefits. """ + serializer_class = MedicationSerializer def post(self, request): data = request.data @@ -129,6 +159,22 @@ class DeleteMedication(APIView): API endpoint to delete medication if medication in database. """ + @extend_schema( + request=inline_serializer(name='DeleteMedicationRequest', fields={ + 'name': drf_serializers.CharField(), + }), + responses={ + 200: inline_serializer(name='DeleteMedicationSuccess', fields={ + 'success': drf_serializers.CharField(), + }), + 400: inline_serializer(name='DeleteMedicationBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + 404: inline_serializer(name='DeleteMedicationNotFound', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def delete(self, request): data = request.data name = data.get('name', '').strip() diff --git a/server/api/views/medRules/serializers.py b/server/api/views/medRules/serializers.py index df5e3663..e0d7d3f3 100644 --- a/server/api/views/medRules/serializers.py +++ b/server/api/views/medRules/serializers.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from drf_spectacular.utils import extend_schema_field from ...models.model_medRule import MedRule, MedRuleSource from ..listMeds.serializers import MedicationSerializer from ...models.model_embeddings import Embeddings @@ -30,6 +31,7 @@ class Meta: "medication_sources", ] + @extend_schema_field(MedicationWithSourcesSerializer(many=True)) def get_medication_sources(self, obj): medrule_sources = MedRuleSource.objects.filter(medrule=obj).select_related( "medication", "embedding" diff --git a/server/api/views/medRules/views.py b/server/api/views/medRules/views.py index 2fae140b..2f80f8f3 100644 --- a/server/api/views/medRules/views.py +++ b/server/api/views/medRules/views.py @@ -1,9 +1,10 @@ from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from rest_framework import status +from rest_framework import status, serializers as drf_serializers from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from drf_spectacular.utils import extend_schema, inline_serializer from ...models.model_medRule import MedRule from .serializers import MedRuleSerializer # You'll need to create this from ..listMeds.models import Medication @@ -13,6 +14,7 @@ @method_decorator(csrf_exempt, name='dispatch') class MedRules(APIView): permission_classes = [IsAuthenticated] + serializer_class = MedRuleSerializer def get(self, request, format=None): # Get all med rules @@ -29,6 +31,27 @@ def get(self, request, format=None): return Response(data, status=status.HTTP_200_OK) + @extend_schema( + request=inline_serializer(name='MedRuleCreateRequest', fields={ + 'rule_type': drf_serializers.CharField(help_text='INCLUDE or EXCLUDE'), + 'history_type': drf_serializers.CharField(help_text='e.g. DIAGNOSIS_DEPRESSED, DIAGNOSIS_MANIC'), + 'reason': drf_serializers.CharField(), + 'label': drf_serializers.CharField(), + 'explanation': drf_serializers.CharField(), + 'medication_names': drf_serializers.ListField(child=drf_serializers.CharField()), + 'chunk_ids': drf_serializers.ListField(child=drf_serializers.IntegerField()), + 'file_guid': drf_serializers.CharField(), + }), + responses={ + 201: MedRuleSerializer, + 400: inline_serializer(name='MedRuleCreateBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + 404: inline_serializer(name='MedRuleCreateNotFound', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def post(self, request): data = request.data diff --git a/server/api/views/risk/views_riskWithSources.py b/server/api/views/risk/views_riskWithSources.py index c02908fc..26cad9f8 100644 --- a/server/api/views/risk/views_riskWithSources.py +++ b/server/api/views/risk/views_riskWithSources.py @@ -1,7 +1,8 @@ from rest_framework.views import APIView from rest_framework.response import Response -from rest_framework import status +from rest_framework import status, serializers as drf_serializers from rest_framework.permissions import AllowAny +from drf_spectacular.utils import extend_schema, inline_serializer from api.views.listMeds.models import Medication from api.models.model_medRule import MedRule, MedRuleSource import openai @@ -11,6 +12,28 @@ class RiskWithSourcesView(APIView): permission_classes = [AllowAny] + @extend_schema( + request=inline_serializer(name='RiskWithSourcesRequest', fields={ + 'drug': drf_serializers.CharField(help_text='Medication name'), + 'source': drf_serializers.CharField(required=False, help_text='One of: include, diagnosis, diagnosis_depressed, diagnosis_manic, diagnosis_hypomanic, diagnosis_euthymic'), + }), + responses={ + 200: inline_serializer(name='RiskWithSourcesResponse', fields={ + 'benefits': drf_serializers.ListField(child=drf_serializers.CharField()), + 'risks': drf_serializers.ListField(child=drf_serializers.CharField()), + 'sources': drf_serializers.ListField(child=drf_serializers.DictField()), + 'medrules_found': drf_serializers.IntegerField(required=False), + 'source_type': drf_serializers.CharField(required=False), + 'note': drf_serializers.CharField(required=False), + }), + 400: inline_serializer(name='RiskWithSourcesBadRequest', fields={ + 'error': drf_serializers.CharField(), + }), + 404: inline_serializer(name='RiskWithSourcesNotFound', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def post(self, request): openai.api_key = os.environ.get("OPENAI_API_KEY") diff --git a/server/api/views/text_extraction/views.py b/server/api/views/text_extraction/views.py index e4122851..020740ad 100644 --- a/server/api/views/text_extraction/views.py +++ b/server/api/views/text_extraction/views.py @@ -9,6 +9,8 @@ from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt import anthropic +from drf_spectacular.utils import extend_schema, inline_serializer, OpenApiParameter +from rest_framework import serializers as drf_serializers from ...services.openai_services import openAIServices from api.models.model_embeddings import Embeddings @@ -97,6 +99,20 @@ class RuleExtractionAPIView(APIView): permission_classes = [IsAuthenticated] + @extend_schema( + parameters=[ + OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=True, description='File GUID to extract rules from'), + ], + responses={ + 200: inline_serializer(name='RuleExtractionResponse', fields={ + 'texts': drf_serializers.CharField(), + 'cited_texts': drf_serializers.CharField(), + }), + 500: inline_serializer(name='RuleExtractionError', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def get(self, request): try: @@ -141,6 +157,19 @@ def openai_extraction(content_chunks, user_prompt): class RuleExtractionAPIOpenAIView(APIView): permission_classes = [IsAuthenticated] + @extend_schema( + parameters=[ + OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=True, description='File GUID to extract rules from'), + ], + responses={ + 200: inline_serializer(name='RuleExtractionOpenAIResponse', fields={ + 'rules': drf_serializers.ListField(child=drf_serializers.DictField()), + }), + 500: inline_serializer(name='RuleExtractionOpenAIError', fields={ + 'error': drf_serializers.CharField(), + }), + } + ) def get(self, request): try: user_prompt = """ diff --git a/server/api/views/uploadFile/test_title.py b/server/api/views/uploadFile/test_title.py index 69979620..ef694e14 100644 --- a/server/api/views/uploadFile/test_title.py +++ b/server/api/views/uploadFile/test_title.py @@ -4,6 +4,39 @@ from . import title +def make_page_dict(blocks): + """Helper to build a get_text("dict") return value from a simple list of blocks. + Each block is a list of (text, font_size) tuples representing spans. + """ + dict_blocks = [] + for spans in blocks: + dict_blocks.append({ + "type": 0, + "lines": [{ + "spans": [{"text": text, "size": size} for text, size in spans] + }] + }) + return {"blocks": dict_blocks} + + +def make_mock_doc(pages_data, metadata=None): + """Build a mock fitz.Document. + pages_data: list of block lists, one per page. Each block is a list of (text, size) tuples. + """ + doc = MagicMock() + doc.metadata = metadata or {"title": None} + doc.__len__ = lambda self: len(pages_data) + + mock_pages = [] + for page_blocks in pages_data: + page = MagicMock() + page.get_text.return_value = make_page_dict(page_blocks) + mock_pages.append(page) + + doc.__getitem__ = lambda self, idx: mock_pages[idx] + return doc + + class TestGenerateTitle(unittest.TestCase): def test_prefers_metadata_title_if_valid(self): doc = MagicMock() @@ -11,59 +44,112 @@ def test_prefers_metadata_title_if_valid(self): self.assertEqual( "A Study Regarding The Efficacy of Drugs", title.generate_title(doc)) - def test_falls_back_to_first_page_text_if_metadata_title_is_empty(self): - doc = MagicMock() - doc.metadata = {"title": ""} - doc[0].get_text = MagicMock() - - foo_block = [None] * 7 - foo_block[4] = "foo" - foo_block[6] = 0 - - title_block = [None] * 7 - title_block[4] = "Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia" - title_block[6] = 0 - - bar_block = [None] * 7 - bar_block[4] = "bar" - bar_block[6] = 0 - doc[0].get_text.return_value = [foo_block, title_block, bar_block] - + def test_falls_back_to_font_size_if_metadata_title_is_empty(self): + doc = make_mock_doc( + pages_data=[[ + [("foo", 10.0)], + [("Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia", 18.0)], + [("bar", 10.0)], + ]], + metadata={"title": ""}, + ) expected_title = "Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia" self.assertEqual(expected_title, title.generate_title(doc)) - def test_falls_back_to_first_page_text_if_metadata_title_does_not_match_regex(self): - doc = MagicMock() - doc.metadata = {"title": "abcd1234"} - doc[0].get_text = MagicMock() - - foo_block = [None] * 7 - foo_block[4] = "foo" - foo_block[6] = 0 - - title_block = [None] * 7 - title_block[4] = "Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia" - title_block[6] = 0 - - bar_block = [None] * 7 - bar_block[4] = "bar" - bar_block[6] = 0 - doc[0].get_text.return_value = [foo_block, title_block, bar_block] - + def test_falls_back_to_font_size_if_metadata_title_does_not_match_regex(self): + doc = make_mock_doc( + pages_data=[[ + [("foo", 10.0)], + [("Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia", 18.0)], + [("bar", 10.0)], + ]], + metadata={"title": "abcd1234"}, + ) expected_title = "Advances in Mood Disorder Pharmacotherapy: Evaluating New Antipsychotics and Mood Stabilizers for Bipolar Disorder and Schizophrenia" self.assertEqual(expected_title, title.generate_title(doc)) - @patch("api.services.openai_services.openAIServices.openAI") + @patch("api.views.uploadFile.title.openAIServices.openAI") def test_falls_back_to_chatgpt_if_no_title_found(self, mock_openAI): - doc = MagicMock() - doc.metadata = {"title": None} - doc.get_text.return_value = [] + doc = make_mock_doc( + pages_data=[[]] # no blocks at all + ) - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = "A Study Regarding The Efficacy of Drugs" - mock_openAI.return_value = mock_response + mock_openAI.return_value = "A Study Regarding The Efficacy of Drugs" - title.generate_title(doc) + result = title.generate_title(doc) self.assertTrue(mock_openAI.called) + self.assertEqual(result, "A Study Regarding The Efficacy of Drugs") + + @patch("api.views.uploadFile.title.openAIServices.openAI") + def test_strips_quotes_from_openai_title(self, mock_openAI): + doc = make_mock_doc(pages_data=[[]]) + + mock_openAI.return_value = '"Updated CANMAT/ISBD Guidelines for Treating Mixed Features in Bipolar Disorder"' + + result = title.generate_title(doc) + + self.assertEqual(result, "Updated CANMAT/ISBD Guidelines for Treating Mixed Features in Bipolar Disorder") + + @patch("api.views.uploadFile.title.openAIServices.openAI") + def test_truncates_long_openai_title(self, mock_openAI): + doc = make_mock_doc(pages_data=[[]]) + + mock_openAI.return_value = "A" * 300 + + result = title.generate_title(doc) + + # Ensure the title is truncated to fit the UploadFile model's title field (max_length=255), since OpenAI responses may exceed this limit + self.assertLessEqual(len(result), 255) + + def test_font_size_joins_adjacent_spans_in_same_block(self): + """A title split across multiple spans in the same block should be joined.""" + doc = make_mock_doc( + pages_data=[[ + [("Author Name", 10.0)], + [("Advances in Mood Disorder", 18.0), ("Pharmacotherapy", 18.0)], + [("Some journal info", 10.0)], + ]], + ) + result = title.extract_title_by_font_size(doc) + self.assertEqual(result, "Advances in Mood Disorder Pharmacotherapy") + + def test_font_size_ignores_short_spans(self): + """Superscript markers and other tiny spans should be filtered out.""" + doc = make_mock_doc( + pages_data=[[ + [("Advances in Mood Disorder Pharmacotherapy", 18.0), ("*", 18.0)], + [("Author Name et al.", 10.0)], + ]], + ) + # The "*" span is < 2 chars, so it should be ignored; title is just the real text + result = title.extract_title_by_font_size(doc) + self.assertEqual(result, "Advances in Mood Disorder Pharmacotherapy") + + def test_font_size_returns_none_when_no_regex_match(self): + """If the largest-font text doesn't match the title regex, return None.""" + doc = make_mock_doc( + pages_data=[[ + # Only 2 words — regex requires at least 3 + [("Psychiatry Research", 18.0)], + [("Author Name et al.", 10.0)], + ]], + ) + result = title.extract_title_by_font_size(doc) + self.assertIsNone(result) + + def test_font_size_finds_title_on_later_page(self): + """Title on page 2 should still be found if it has the largest font.""" + doc = make_mock_doc( + pages_data=[ + [ # page 1: cover page with smaller text + [("Some preamble text here", 12.0)], + ], + [ # page 2: actual title in larger font + [("Advances in Mood Disorder Pharmacotherapy", 18.0)], + [("Author Name et al.", 10.0)], + ], + ], + ) + result = title.extract_title_by_font_size(doc) + self.assertEqual(result, "Advances in Mood Disorder Pharmacotherapy") diff --git a/server/api/views/uploadFile/title.py b/server/api/views/uploadFile/title.py index 06e0ce0c..38dcd5d5 100644 --- a/server/api/views/uploadFile/title.py +++ b/server/api/views/uploadFile/title.py @@ -6,44 +6,89 @@ # regular expression to match common research white paper titles. Created by Chat-gpt -# requires at least 3 words, no dates, no version numbers. +# requires at least 3 words, no version numbers. title_regex = re.compile( - r'^(?=(?:\b\w+\b[\s:,\-\(\)]*){3,})(?!.*\b(?:19|20)\d{2}\b)(?!.*\bv\d+\b)[A-Za-z0-9][\w\s:,\-\(\)]*[A-Za-z\)]$', re.IGNORECASE) + r"^(?=(?:\b\w+\b[^A-Za-z0-9]*){3,})(?!.*\bv\d+\b)[A-Za-z0-9].+[A-Za-z\)?!]$", re.IGNORECASE) def generate_title(pdf: fitz.Document) -> str | None: document_metadata_title = pdf.metadata["title"] if document_metadata_title is not None and document_metadata_title != "": if title_regex.match(document_metadata_title): - print("suitable title was found in metadata") return document_metadata_title.strip() - else: - print("metadata title did not match regex") - print("Looking for title in first page text") - first_page = pdf[0] - first_page_blocks = first_page.get_text("blocks") - text_blocks = [ - block[4].strip().replace("\n", " ") - for block in first_page_blocks - if block[6] == 0 # only include text blocks. - ] - - # For some reason, extracted PDF text has extra spaces. Collapse them here. - regex = r"\s{2,}" - text_blocks = [re.sub(regex, " ", text) for text in text_blocks] - - if len(text_blocks) != 0: - for text in text_blocks: - if title_regex.match(text): - return text - - print( - "no suitable title found in first page text. Using GPT-4 to summarize the PDF") + font_title = extract_title_by_font_size(pdf) + if font_title: + return font_title + gpt_title = summarize_pdf(pdf) return gpt_title or None +def extract_title_by_font_size(pdf: fitz.Document, max_pages: int = 3) -> str | None: + """ + Extract the title by finding the largest font size across the first few pages + and collecting contiguous runs of text at that size. + """ + pages_to_scan = min(max_pages, len(pdf)) + + # First pass: collect all spans with their font size, and find the max font size. + all_spans = [] + max_font_size = 0.0 + + for page_idx in range(pages_to_scan): + page_dict = pdf[page_idx].get_text("dict") + for block in page_dict["blocks"]: + if block.get("type") != 0: + continue + for line in block["lines"]: + for span in line["spans"]: + text = span["text"].strip() + size = span["size"] + if len(text) < 2 or size < 6.0: + continue + all_spans.append({"text": text, "size": size}) + if size > max_font_size: + max_font_size = size + + if max_font_size == 0.0: + return None + + # Second pass: gather contiguous runs of spans at the max font size. + # Runs continue across block boundaries so multi-block titles (e.g., + # "BIPOLAR DISORDER IN PRIMARY CARE:" in one block and "DIAGNOSIS AND + # MANAGEMENT" in the next) are joined into a single candidate. + # A run only ends when a non-max-size span interrupts it. + candidates = [] + current_run = [] + + for span in all_spans: + if span["size"] == max_font_size: + current_run.append(span["text"]) + else: + if current_run: + candidates.append(" ".join(current_run)) + current_run = [] + + if current_run: + candidates.append(" ".join(current_run)) + + # Collapse extra whitespace, validate against title regex, and pick the longest match. + # Longest wins because real titles are typically longer than section headers + # (e.g., "About the Author") that may share the same max font size. + best = None + for candidate in candidates: + cleaned = re.sub(r"\s{2,}", " ", candidate).strip() + if title_regex.match(cleaned): + if best is None or len(cleaned) > len(best): + best = cleaned + + if best: + return best[:255] + + return None + + def summarize_pdf(pdf: fitz.Document) -> str: """ Summarize a PDF document using OpenAI's GPT-4 model. @@ -58,4 +103,6 @@ def summarize_pdf(pdf: fitz.Document) -> str: prompt = "Please provide a title for this document. The title should be less than 256 characters and will be displayed on a webpage." response = openAIServices.openAI( first_page_content, prompt, model='gpt-4o', temp=0.0) - return response.choices[0].message.content + title = response.strip().strip('"').strip("'") + # Truncate to fit UploadFile model's max_length=255 title field as a final safeguard + return title[:255] diff --git a/server/api/views/uploadFile/views.py b/server/api/views/uploadFile/views.py index 69dfb996..eda43b76 100644 --- a/server/api/views/uploadFile/views.py +++ b/server/api/views/uploadFile/views.py @@ -1,8 +1,9 @@ from rest_framework.views import APIView from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response -from rest_framework import status +from rest_framework import status, serializers as drf_serializers from rest_framework.generics import UpdateAPIView +from drf_spectacular.utils import extend_schema, inline_serializer, OpenApiResponse import pdfplumber from .models import UploadFile # Import your UploadFile model from .serializers import UploadFileSerializer @@ -12,9 +13,14 @@ import fitz from django.db import transaction from .title import generate_title +import logging + +logger = logging.getLogger(__name__) class UploadFileView(APIView): + serializer_class = UploadFileSerializer + def get_permissions(self): if self.request.method == 'GET': return [AllowAny()] # Public access @@ -28,6 +34,23 @@ def get(self, request, format=None): serializer = UploadFileSerializer(files, many=True) return Response(serializer.data) + @extend_schema( + request={'multipart/form-data': inline_serializer( + name='UploadFileRequest', + fields={ + 'file': drf_serializers.FileField(help_text='PDF file to upload'), + } + )}, + responses={ + 201: inline_serializer(name='UploadFileSuccess', fields={ + 'message': drf_serializers.CharField(), + 'file_id': drf_serializers.IntegerField(), + }), + 400: inline_serializer(name='UploadFileBadRequest', fields={ + 'message': drf_serializers.CharField(), + }), + } + ) def post(self, request, format=None): print(request.auth) print(f"UploadFileView post called. Path: {request.path}") @@ -124,9 +147,26 @@ def post(self, request, format=None): ) except Exception as e: # Handle potential errors + logger.exception("File upload failed for '%s': %s", uploaded_file.name, e) return Response({"message": f"Error processing file and embeddings: {str(e)}"}, status=status.HTTP_400_BAD_REQUEST) + @extend_schema( + request=inline_serializer(name='DeleteFileRequest', fields={ + 'guid': drf_serializers.CharField(help_text='GUID of file to delete'), + }), + responses={ + 200: inline_serializer(name='DeleteFileSuccess', fields={ + 'message': drf_serializers.CharField(), + }), + 403: inline_serializer(name='DeleteFileForbidden', fields={ + 'message': drf_serializers.CharField(), + }), + 404: inline_serializer(name='DeleteFileNotFound', fields={ + 'message': drf_serializers.CharField(), + }), + } + ) def delete(self, request, format=None): guid = request.data.get('guid') if not guid: @@ -157,6 +197,14 @@ def delete(self, request, format=None): class RetrieveUploadFileView(APIView): permission_classes = [AllowAny] + @extend_schema( + responses={ + (200, 'application/pdf'): OpenApiResponse(description='PDF file binary content'), + 404: inline_serializer(name='RetrieveFileNotFound', fields={ + 'message': drf_serializers.CharField(), + }), + } + ) def get(self, request, guid, format=None): try: file = UploadFile.objects.get(guid=guid) diff --git a/server/api/views/version/views.py b/server/api/views/version/views.py index b79d6577..af59e9e0 100644 --- a/server/api/views/version/views.py +++ b/server/api/views/version/views.py @@ -3,11 +3,18 @@ from rest_framework.permissions import AllowAny from rest_framework.views import APIView from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from drf_spectacular.utils import extend_schema, inline_serializer class VersionView(APIView): permission_classes = [AllowAny] + @extend_schema( + responses={200: inline_serializer(name='VersionResponse', fields={ + 'version': drf_serializers.CharField(), + })} + ) def get(self, request, *args, **kwargs): version = os.environ.get("VERSION") or "dev" return Response({"version": version}) diff --git a/server/balancer_backend/settings.py b/server/balancer_backend/settings.py index 9f917a94..a4ccaaae 100644 --- a/server/balancer_backend/settings.py +++ b/server/balancer_backend/settings.py @@ -51,6 +51,7 @@ "corsheaders", "rest_framework", "djoser", + 'drf_spectacular', ] MIDDLEWARE = [ @@ -195,8 +196,19 @@ "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework_simplejwt.authentication.JWTAuthentication", ), + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', } +SPECTACULAR_SETTINGS = { + 'TITLE': 'Balancer API', + 'DESCRIPTION': 'API for the Balancer medication decision support tool', + 'VERSION': '1.0.0', + 'SERVE_INCLUDE_SCHEMA': False, + 'SECURITY': [{'jwtAuth': []}], + 'SWAGGER_UI_SETTINGS': { + 'persistAuthorization': True, + }, +} SIMPLE_JWT = { "AUTH_HEADER_TYPES": ("JWT",), diff --git a/server/balancer_backend/urls.py b/server/balancer_backend/urls.py index c8bd290d..55bd2032 100644 --- a/server/balancer_backend/urls.py +++ b/server/balancer_backend/urls.py @@ -6,6 +6,9 @@ # Import TemplateView for rendering templates from django.views.generic import TemplateView import importlib # Import the importlib module for dynamic module importing +from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView, SpectacularRedocView + + # Define a list of URL patterns for the application # Keep admin outside /api/ prefix @@ -50,6 +53,9 @@ # Wrap all API routes under /api/ prefix urlpatterns += [ path("api/", include(api_urlpatterns)), + path("api/schema/", SpectacularAPIView.as_view(), name="schema"), + path("api/docs/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"), + path("api/redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"), ] import os diff --git a/server/pytest.ini b/server/pytest.ini new file mode 100644 index 00000000..235b9752 --- /dev/null +++ b/server/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +DJANGO_SETTINGS_MODULE = balancer_backend.settings +pythonpath = . diff --git a/server/requirements.txt b/server/requirements.txt index bbaf7bc9..f952b200 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -18,4 +18,7 @@ sentence_transformers PyMuPDF==1.24.0 Pillow pytesseract -anthropic \ No newline at end of file +anthropic +pytest +pytest-django +drf-spectacular