This commit is contained in:
Markos Gogoulos
2025-12-30 18:38:07 +02:00
parent d3c858173f
commit 5070050fa4
2 changed files with 46 additions and 15 deletions

View File

@@ -4,8 +4,10 @@ PyLTI1p3 Django adapters for MediaCMS
Provides Django-specific implementations for PyLTI1p3 interfaces Provides Django-specific implementations for PyLTI1p3 interfaces
""" """
import hashlib
import json import json
import time import time
import uuid
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import jwt import jwt
@@ -192,8 +194,10 @@ class DjangoServiceConnector(ServiceConnector):
self._access_token_expires = 0 self._access_token_expires = 0
def get_access_token(self, scopes): def get_access_token(self, scopes):
if self._access_token and time.time() < self._access_token_expires: cache_key = 'lti_access_token_' + self._registration.get_issuer() + '_' + hashlib.sha1(' '.join(scopes).encode('utf-8')).hexdigest()
return self._access_token token_data = cache.get(cache_key)
if token_data:
return token_data['access_token']
key_obj = LTIToolKeys.get_or_create_keys() key_obj = LTIToolKeys.get_or_create_keys()
jwk_obj = jwk.JWK(**key_obj.private_key_jwk) jwk_obj = jwk.JWK(**key_obj.private_key_jwk)
@@ -201,13 +205,17 @@ class DjangoServiceConnector(ServiceConnector):
private_key = serialization.load_pem_private_key(pem_bytes, password=None, backend=default_backend()) private_key = serialization.load_pem_private_key(pem_bytes, password=None, backend=default_backend())
now = int(time.time()) now = int(time.time())
# Moodle can be picky about audience. Including both token URL and issuer is safer.
audience = [self._registration.get_auth_token_url(), self._registration.get_issuer()]
payload = { payload = {
'iss': self._registration.get_client_id(), 'iss': self._registration.get_client_id(),
'sub': self._registration.get_client_id(), 'sub': self._registration.get_client_id(),
'aud': self._registration.get_auth_token_url(), 'aud': audience,
'iat': now, 'iat': now,
'exp': now + 300, 'exp': now + 300,
'jti': str(time.time()), 'jti': str(uuid.uuid4()),
} }
client_assertion = jwt.encode(payload, private_key, algorithm='RS256', headers={'kid': key_obj.private_key_jwk['kid']}) client_assertion = jwt.encode(payload, private_key, algorithm='RS256', headers={'kid': key_obj.private_key_jwk['kid']})
@@ -220,15 +228,24 @@ class DjangoServiceConnector(ServiceConnector):
'scope': ' '.join(scopes), 'scope': ' '.join(scopes),
} }
print(f"LTI Service: Requesting access token from {token_url} with scopes: {scopes}")
response = requests.post(token_url, data=data, timeout=10) response = requests.post(token_url, data=data, timeout=10)
response.raise_for_status()
token_data = response.json() try:
self._access_token = token_data['access_token'] response.raise_for_status()
expires_in = token_data.get('expires_in', 3600) token_data = response.json()
self._access_token_expires = time.time() + expires_in - 10 print(f"LTI Service: Successfully received access token. Expires in: {token_data.get('expires_in', 'N/A')}")
return self._access_token expires_in = token_data.get('expires_in', 3600)
cache.set(cache_key, token_data, timeout=expires_in - 10)
return token_data['access_token']
except requests.exceptions.HTTPError as e:
print(f"LTI Service Error: Failed to get access token. Status: {e.response.status_code}, Response: {e.response.text}")
raise
except json.JSONDecodeError:
print(f"LTI Service Error: Failed to decode JSON from token endpoint. Response: {response.text}")
raise
def make_service_request(self, scopes, url, is_post=False, data=None, **kwargs): def make_service_request(self, scopes, url, is_post=False, data=None, **kwargs):
access_token = self.get_access_token(scopes) access_token = self.get_access_token(scopes)

View File

@@ -158,11 +158,25 @@ class LaunchView(View):
unverified = jwt.decode(id_token, options={"verify_signature": False}) unverified = jwt.decode(id_token, options={"verify_signature": False})
iss = unverified.get('iss') iss = unverified.get('iss')
aud = unverified.get('aud') aud = unverified.get('aud') # Can be a string or a list
try:
platform = LTIPlatform.objects.get(platform_id=iss, client_id=aud) platform = None
except LTIPlatform.DoesNotExist: if isinstance(aud, list):
raise # If aud is a list, find a platform where the client_id is in the list
platforms = LTIPlatform.objects.filter(platform_id=iss, client_id__in=aud)
if platforms.count() == 1:
platform = platforms.first()
elif platforms.count() > 1:
raise LtiException(f"Multiple platforms found for issuer '{iss}' and client_ids '{aud}'")
else:
# If aud is a string, find it directly
try:
platform = LTIPlatform.objects.get(platform_id=iss, client_id=aud)
except LTIPlatform.DoesNotExist:
pass # Platform will be None
if not platform:
raise LtiException(f"Platform not found for issuer '{iss}' and client_id(s) '{aud}'")
tool_config = DjangoToolConfig.from_platform(platform) tool_config = DjangoToolConfig.from_platform(platform)