first sync
Some checks failed
Deployment Verification / deploy-and-test (push) Failing after 29s

This commit is contained in:
2025-03-04 07:59:21 +01:00
parent 9cdcf486b6
commit 506716e703
1450 changed files with 577316 additions and 62 deletions

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from app.models import Cases
from app.models.authorization import Group
from app.models.authorization import GroupCaseAccess
from app.models.authorization import Organisation
from app.models.authorization import OrganisationCaseAccess
from app.models.authorization import User
from app.models.authorization import UserCaseAccess
def manage_ac_audit_users_db():
uca = UserCaseAccess.query.with_entities(
User.name,
User.user,
User.id,
User.uuid,
UserCaseAccess.access_level,
Cases.name,
Cases.case_id
).join(
UserCaseAccess.case,
UserCaseAccess.user
).all()
gca = GroupCaseAccess.query.with_entities(
Group.group_name,
Group.group_id,
Group.group_uuid,
GroupCaseAccess.access_level,
Cases.name,
Cases.case_id
).join(
GroupCaseAccess.case,
GroupCaseAccess.group
).all()
oca = OrganisationCaseAccess.query.with_entities(
Organisation.org_name,
Organisation.org_id,
Organisation.org_uuid,
OrganisationCaseAccess.access_level,
Cases.name,
Cases.case_id
).all()
ret = {
'users': [u._asdict() for u in uca],
'groups': [g._asdict() for g in gca],
'organisations': [o._asdict() for o in oca]
}
return ret

View File

@ -0,0 +1,292 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# Copyright (C) 2021 - Airbus CyberSecurity (SAS)
# ir@cyberactionlab.net
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import json
import logging as logger
from sqlalchemy.orm.attributes import flag_modified
from app import db, app
from app.models import CaseAssets
from app.models import CaseReceivedFile
from app.models import CaseTasks
from app.models import Cases
from app.models import CasesEvent
from app.models import Client
from app.models import CustomAttribute
from app.models import Ioc
from app.models import Notes
log = logger.getLogger(__name__)
def update_all_attributes(object_type, previous_attribute, partial_overwrite=False, complete_overwrite=False):
obj_list = []
if object_type == 'ioc':
obj_list = Ioc.query.all()
elif object_type == 'event':
obj_list = CasesEvent.query.all()
elif object_type == 'asset':
obj_list = CaseAssets.query.all()
elif object_type == 'task':
obj_list = CaseTasks.query.all()
elif object_type == 'note':
obj_list = Notes.query.all()
elif object_type == 'evidence':
obj_list = CaseReceivedFile.query.all()
elif object_type == 'case':
obj_list = Cases.query.all()
elif object_type == 'client':
obj_list = Client.query.all()
target_attr = get_default_custom_attributes(object_type)
app.logger.info(f'Migrating {len(obj_list)} objects of type {object_type}')
for obj in obj_list:
if complete_overwrite or obj.custom_attributes is None:
app.logger.info('Achieving complete overwrite')
obj.custom_attributes = target_attr
flag_modified(obj, "custom_attributes")
db.session.commit()
continue
for tab in target_attr:
if obj.custom_attributes.get(tab) is None or partial_overwrite:
app.logger.info(f'Migrating {tab}')
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab] = target_attr[tab]
else:
for element in target_attr[tab]:
if element not in obj.custom_attributes[tab]:
app.logger.info(f'Migrating {element}')
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab][element] = target_attr[tab][element]
else:
if obj.custom_attributes[tab][element]['type'] != target_attr[tab][element]['type']:
if (obj.custom_attributes[tab][element]['value'] == target_attr[tab][element]['value']) or \
(obj.custom_attributes[tab][element]['type'] in ('input_string', 'input_text_field') and
target_attr[tab][element]['type'] in ('input_string', 'input_text_field')):
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab][element]['type'] = target_attr[tab][element]['type']
if 'mandatory' in target_attr[tab][element] \
and obj.custom_attributes[tab][element]['mandatory'] != target_attr[tab][element]['mandatory']:
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab][element]['mandatory'] = target_attr[tab][element]['mandatory']
if partial_overwrite:
for tab in previous_attribute:
if not target_attr.get(tab):
if obj.custom_attributes.get(tab):
flag_modified(obj, "custom_attributes")
obj.custom_attributes.pop(tab)
for element in previous_attribute[tab]:
if target_attr.get(tab):
if not target_attr[tab].get(element):
if obj.custom_attributes[tab].get(element):
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab].pop(element)
# Commit will only be effective if we flagged a modification, reducing load on the DB
db.session.commit()
def get_default_custom_attributes(object_type):
ca = CustomAttribute.query.filter(CustomAttribute.attribute_for == object_type).first()
return ca.attribute_content
def add_tab_attribute(obj, tab_name):
"""
Add a new custom tab to an object ID
"""
if not obj:
return False
attribute = obj.custom_attributes
if tab_name in attribute:
return True
else:
attribute[tab_name] = {}
flag_modified(obj, "custom_attributes")
db.session.commit()
return True
def add_tab_attribute_field(obj, tab_name, field_name, field_type, field_value, mandatory=None, field_options=None):
if not obj:
return False
attribute = obj.custom_attributes
if attribute is None:
attribute = {}
if tab_name not in attribute:
attribute[tab_name] = {}
attr = {
field_name: {
"mandatory": mandatory if mandatory is not None else False,
"type": field_type,
"value": field_value
}
}
if field_options:
attr[field_name]['options'] = field_options
attribute[tab_name][field_name] = attr[field_name]
obj.custom_attributes = attribute
flag_modified(obj, "custom_attributes")
db.session.commit()
return True
def merge_custom_attributes(data, obj_id, object_type, overwrite=False):
obj = None
if obj_id:
if object_type == 'ioc':
obj = Ioc.query.filter(Ioc.ioc_id == obj_id).first()
elif object_type == 'event':
obj = CasesEvent.query.filter(CasesEvent.event_id == obj_id).first()
elif object_type == 'asset':
obj = CaseAssets.query.filter(CaseAssets.asset_id == obj_id).first()
elif object_type == 'task':
obj = CaseTasks.query.filter(CaseTasks.id == obj_id).first()
elif object_type == 'note':
obj = Notes.query.filter(Notes.note_id == obj_id).first()
elif object_type == 'evidence':
obj = CaseReceivedFile.query.filter(CaseReceivedFile.id == obj_id).first()
elif object_type == 'case':
obj = Cases.query.filter(Cases.case_id == obj_id).first()
elif object_type == 'client':
obj = Client.query.filter(Client.client_id == obj_id).first()
if not obj:
return data
if overwrite:
log.warning(f'Overwriting all {object_type}')
return get_default_custom_attributes(object_type)
for tab in data:
if obj.custom_attributes.get(tab) is None:
log.error(f'Missing tab {tab} in {object_type}')
continue
for field in data[tab]:
if field not in obj.custom_attributes[tab]:
log.error(f'Missing field {field} in {object_type}')
else:
if obj.custom_attributes[tab][field]['type'] == 'html':
continue
if obj.custom_attributes[tab][field]['value'] != data[tab][field]:
flag_modified(obj, "custom_attributes")
obj.custom_attributes[tab][field]['value'] = data[tab][field]
# Commit will only be effective if we flagged a modification, reducing load on the DB
db.session.commit()
return obj.custom_attributes
else:
default_attr = get_default_custom_attributes(object_type)
for tab in data:
if default_attr.get(tab) is None:
app.logger.info(f'Missing tab {tab} in {object_type} default attribute')
continue
for field in data[tab]:
if field not in default_attr[tab]:
app.logger.info(f'Missing field {field} in {object_type} default attribute')
else:
default_attr[tab][field]['value'] = data[tab][field]
return default_attr
def validate_attribute(attribute):
logs = []
try:
data = json.loads(attribute)
except Exception as e:
return None, [str(e)]
for tab in data:
for field in data[tab]:
if not data[tab][field].get('type'):
logs.append(f'{tab}::{field} is missing mandatory "type" tag')
continue
field_type = data[tab][field].get('type')
if field_type in ['input_string', 'input_textfield', 'input_checkbox', 'input_select',
'input_date', 'input_datetime']:
if data[tab][field].get('mandatory') is None:
logs.append(f'{tab} -> {field} of type {field_type} is missing mandatory "mandatory" tag')
elif not isinstance(data[tab][field].get('mandatory'), bool):
logs.append(f'{tab} -> {field} -> "mandatory" expects a value of type bool, '
f'but got {type(data[tab][field].get("mandatory"))}')
if data[tab][field].get('value') is None:
logs.append(f'{tab} -> {field} of type {field_type} is missing mandatory "value" tag')
if field_type == 'input_checkbox' and not isinstance(data[tab][field].get('value'), bool):
logs.append(f'{tab} -> {field} of type {field_type} expects a value of type bool, '
f'but got {type(data[tab][field]["value"])}')
if field_type in ['input_string', 'input_textfield', 'input_date', 'input_datetime']:
if not isinstance(data[tab][field].get('value'), str):
logs.append(f'{tab} -> {field} of type {field_type} expects a value of type str, '
f'but got {type(data[tab][field]["value"])}')
if field_type == 'input_select':
if data[tab][field].get('options') is None:
logs.append(f'{tab} -> {field} of type {field_type} is missing mandatory "options" tag')
continue
if not isinstance(data[tab][field].get('options'), list):
logs.append(f'{tab} -> {field} of type {field_type} expects a value of type list, '
f'but got {type(data[tab][field]["value"])}')
for opt in data[tab][field].get('options'):
if not isinstance(opt, str):
logs.append(f'{tab} -> {field} -> "options" expects a list of str, '
f'but got {type(opt)}')
elif field_type in ['raw', 'html']:
if data[tab][field].get('value') is None:
logs.append(f'{tab} -> {field} of type {field_type} is missing mandatory "value" tag')
else:
logs.append(f'{tab} -> {field}, unknown field type "{field_type}"')
return data, logs

View File

@ -0,0 +1,82 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from sqlalchemy import func
from typing import List
from app.models import CaseClassification
def get_case_classifications_list() -> List[dict]:
"""Get a list of case classifications
Returns:
List[dict]: List of case classifications
"""
case_classifications = CaseClassification.query.with_entities(
CaseClassification.id,
CaseClassification.name,
CaseClassification.name_expanded,
CaseClassification.description,
CaseClassification.creation_date
).all()
c_cl = [row._asdict() for row in case_classifications]
return c_cl
def get_case_classification_by_id(cur_id: int) -> CaseClassification:
"""Get a case classification
Args:
cur_id (int): case classification id
Returns:
CaseClassification: Case classification
"""
case_classification = CaseClassification.query.filter_by(id=cur_id).first()
return case_classification
def get_case_classification_by_name(cur_name: str) -> CaseClassification:
"""Get a case classification
Args:
cur_name (str): case classification name
Returns:
CaseClassification: Case classification
"""
case_classification = CaseClassification.query.filter_by(name=cur_name).first()
return case_classification
def search_classification_by_name(name: str, exact_match: bool = False) -> List[dict]:
"""Search for a case classification by name
Args:
name (str): case classification name
exact_match (bool, optional): Exact match. Defaults to False.
Returns:
List[dict]: List of case classifications
"""
if exact_match:
return CaseClassification.query.filter(func.lower(CaseClassification.name) == name.lower()).all()
return CaseClassification.query.filter(CaseClassification.name.ilike(f'%{name}%')).all()

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# Copyright (C) 2021 - Airbus CyberSecurity (SAS)
# ir@cyberactionlab.net
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from sqlalchemy import func
from app.models import AnalysisStatus, IocType, AssetsType, EventCategory
def search_analysis_status_by_name(name: str, exact_match: bool = False) -> AnalysisStatus:
"""
Search an analysis status by its name
args:
name: the name of the analysis status
exact_match: if True, the name must be exactly the same as the one in the database
return: the analysis status
"""
if exact_match:
return AnalysisStatus.query.filter(func.lower(AnalysisStatus.name) == name.lower()).all()
return AnalysisStatus.query.filter(AnalysisStatus.name.ilike(f'%{name}%')).all()
def search_ioc_type_by_name(name: str, exact_match: bool = False) -> IocType:
"""
Search an IOC type by its name
args:
name: the name of the IOC type
exact_match: if True, the name must be exactly the same as the one in the database
return: the IOC type
"""
if exact_match:
return IocType.query.filter(func.lower(IocType.type_name) == name.lower()).all()
return IocType.query.filter(IocType.type_name.ilike(f'%{name}%')).all()
def search_asset_type_by_name(name: str, exact_match: bool = False) -> AssetsType:
"""
Search an asset type by its name
args:
name: the name of the asset type
exact_match: if True, the name must be exactly the same as the one in the database
return: the asset type
"""
if exact_match:
return AssetsType.query.filter(func.lower(AssetsType.asset_name) == name.lower()).all()
return AssetsType.query.filter(AssetsType.asset_name.ilike(f'%{name}%')).all()
def search_event_category_by_name(name: str, exact_match: bool = False) -> AssetsType:
"""
Search an event category by its name
args:
name: the name of the event category
exact_match: if True, the name must be exactly the same as the one in the database
return: the event category
"""
if exact_match:
return EventCategory.query.filter(func.lower(EventCategory.name) == name.lower()).all()
return EventCategory.query.filter(EventCategory.name.ilike(f'%{name}%')).all()

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from typing import List
from app.models.cases import CaseState
from app.schema.marshables import CaseStateSchema
def get_case_states_list() -> List[dict]:
"""Get a list of case state
Returns:
List[dict]: List of case state
"""
case_state = CaseState.query.all()
return CaseStateSchema(many=True).dump(case_state)
def get_case_state_by_id(cur_id: int) -> CaseState:
"""Get a case state
Args:
cur_id (int): case state id
Returns:
CaseState: Case state
"""
case_state = CaseState.query.filter_by(state_id=cur_id).first()
return case_state
def get_case_state_by_name(cur_name: str) -> CaseState:
"""Get a case state
Args:
cur_name (str): case state name
Returns:
CaseState: Case state
"""
case_state = CaseState.query.filter_by(state_name=cur_name).first()
return case_state
def get_cases_using_state(cur_id: int) -> List[dict]:
"""Get a list of cases using a case state
Args:
cur_id (int): case state id
Returns:
List[dict]: List of cases
"""
case_state = get_case_state_by_id(cur_id)
return case_state.cases

View File

@ -0,0 +1,302 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from datetime import datetime
from typing import List, Optional, Union
import marshmallow
from app import db
from app.datamgmt.case.case_notes_db import add_note_group, add_note
from app.datamgmt.case.case_tasks_db import add_task
from app.datamgmt.manage.manage_case_classifications_db import get_case_classification_by_name
from app.iris_engine.module_handler.module_handler import call_modules_hook
from app.models import CaseTemplate, Cases, Tags, NotesGroup
from app.models.authorization import User
from app.schema.marshables import CaseSchema, CaseTaskSchema, CaseGroupNoteSchema, CaseAddNoteSchema
def get_case_templates_list() -> List[dict]:
"""Get a list of case templates
Returns:
List[dict]: List of case templates
"""
case_templates = CaseTemplate.query.with_entities(
CaseTemplate.id,
CaseTemplate.name,
CaseTemplate.display_name,
CaseTemplate.description,
CaseTemplate.title_prefix,
CaseTemplate.author,
CaseTemplate.created_at,
CaseTemplate.classification,
CaseTemplate.updated_at,
User.name.label('added_by')
).join(
CaseTemplate.created_by_user
).all()
c_cl = [row._asdict() for row in case_templates]
return c_cl
def get_case_template_by_id(cur_id: int) -> CaseTemplate:
"""Get a case template
Args:
cur_id (int): case template id
Returns:
CaseTemplate: Case template
"""
case_template = CaseTemplate.query.filter_by(id=cur_id).first()
return case_template
def delete_case_template_by_id(case_template_id: int):
"""Delete a case template
Args:
case_template_id (int): case template id
"""
CaseTemplate.query.filter_by(id=case_template_id).delete()
def validate_case_template(data: dict, update: bool = False) -> Optional[str]:
try:
if not update:
# If it's not an update, we check the required fields
if "name" not in data:
return "Name is required."
if "display_name" not in data or not data["display_name"].strip():
data["display_name"] = data["name"]
# We check that name is not empty
if "name" in data and not data["name"].strip():
return "Name cannot be empty."
# We check that author length is not above 128 chars
if "author" in data and len(data["author"]) > 128:
return "Author cannot be longer than 128 characters."
# We check that author length is not above 128 chars
if "author" in data and len(data["author"]) > 128:
return "Author cannot be longer than 128 characters."
# We check that prefix length is not above 32 chars
if "title_prefix" in data and len(data["title_prefix"]) > 32:
return "Prefix cannot be longer than 32 characters."
# We check that tags, if any, are a list of strings
if "tags" in data:
if not isinstance(data["tags"], list):
return "Tags must be a list."
for tag in data["tags"]:
if not isinstance(tag, str):
return "Each tag must be a string."
# We check that tasks, if any, are a list of dictionaries with mandatory keys
if "tasks" in data:
if not isinstance(data["tasks"], list):
return "Tasks must be a list."
for task in data["tasks"]:
if not isinstance(task, dict):
return "Each task must be a dictionary."
if "title" not in task:
return "Each task must have a 'title' field."
if "tags" in task:
if not isinstance(task["tags"], list):
return "Task tags must be a list."
for tag in task["tags"]:
if not isinstance(tag, str):
return "Each tag must be a string."
# We check that note groups, if any, are a list of dictionaries with mandatory keys
if "note_groups" in data:
if not isinstance(data["note_groups"], list):
return "Note groups must be a list."
for note_group in data["note_groups"]:
if not isinstance(note_group, dict):
return "Each note group must be a dictionary."
if "title" not in note_group:
return "Each note group must have a 'title' field."
if "notes" in note_group:
if not isinstance(note_group["notes"], list):
return "Notes must be a list."
for note in note_group["notes"]:
if not isinstance(note, dict):
return "Each note must be a dictionary."
if "title" not in note:
return "Each note must have a 'title' field."
# If all checks succeeded, we return None to indicate everything is has been validated
return None
except Exception as e:
return str(e)
def case_template_pre_modifier(case_schema: CaseSchema, case_template_id: str):
case_template = get_case_template_by_id(int(case_template_id))
if not case_template:
return None
if case_template.title_prefix:
case_schema.name = case_template.title_prefix + " " + case_schema.name[0]
case_classification = get_case_classification_by_name(case_template.classification)
if case_classification:
case_schema.classification_id = case_classification.id
return case_schema
def case_template_populate_tasks(case: Cases, case_template: CaseTemplate):
logs = []
# Update case tasks
for task_template in case_template.tasks:
try:
# validate before saving
task_schema = CaseTaskSchema()
# Remap case task template fields
# Set status to "To Do" which is ID 1
mapped_task_template = {
"task_title": task_template['title'],
"task_description": task_template['description'] if task_template.get('description') else "",
"task_tags": ",".join(tag for tag in task_template["tags"]) if task_template.get('tags') else "",
"task_status_id": 1
}
mapped_task_template = call_modules_hook('on_preload_task_create', data=mapped_task_template, caseid=case.case_id)
task = task_schema.load(mapped_task_template)
assignee_id_list = []
ctask = add_task(task=task,
assignee_id_list=assignee_id_list,
user_id=case.user_id,
caseid=case.case_id
)
ctask = call_modules_hook('on_postload_task_create', data=ctask, caseid=case.case_id)
if not ctask:
logs.append("Unable to create task for internal reasons")
except marshmallow.exceptions.ValidationError as e:
logs.append(e.messages)
return logs
def case_template_populate_notes(case: Cases, note_group_template: dict, ng: NotesGroup):
logs = []
if note_group_template.get("notes"):
for note_template in note_group_template["notes"]:
# validate before saving
note_schema = CaseAddNoteSchema()
mapped_note_template = {
"group_id": ng.group_id,
"note_title": note_template["title"],
"note_content": note_template["content"] if note_template.get("content") else ""
}
mapped_note_template = call_modules_hook('on_preload_note_create', data=mapped_note_template, caseid=case.case_id)
note_schema.verify_group_id(mapped_note_template, caseid=ng.group_case_id)
note = note_schema.load(mapped_note_template)
cnote = add_note(note.get('note_title'),
datetime.utcnow(),
case.user_id,
case.case_id,
note.get('group_id'),
note_content=note.get('note_content'))
cnote = call_modules_hook('on_postload_note_create', data=cnote, caseid=case.case_id)
if not cnote:
logs.append("Unable to add note for internal reasons")
break
return logs
def case_template_populate_note_groups(case: Cases, case_template: CaseTemplate):
logs = []
# Update case tasks
for note_group_template in case_template.note_groups:
try:
# validate before saving
note_group_schema = CaseGroupNoteSchema()
# Remap case task template fields
# Set status to "To Do" which is ID 1
mapped_note_group_template = {
"group_title": note_group_template['title']
}
note_group = note_group_schema.load(mapped_note_group_template)
ng = add_note_group(group_title=note_group.group_title,
caseid=case.case_id,
userid=case.user_id,
creationdate=datetime.utcnow())
if not ng:
logs.append("Unable to add note group for internal reasons")
break
logs = case_template_populate_notes(case, note_group_template, ng)
except marshmallow.exceptions.ValidationError as e:
logs.append(e.messages)
return logs
def case_template_post_modifier(case: Cases, case_template_id: Union[str, int]):
case_template = get_case_template_by_id(int(case_template_id))
logs = []
if not case_template:
logs.append(f"Case template {case_template_id} not found")
return None, logs
# Update summary, we want to append in order not to skip the initial case description
case.description += "\n" + case_template.summary
# Update case tags
for tag_str in case_template.tags:
tag = Tags(tag_title=tag_str)
tag = tag.save()
case.tags.append(tag)
# Update case tasks
logs = case_template_populate_tasks(case, case_template)
if logs:
return case, logs
# Update case note groups
logs = case_template_populate_note_groups(case, case_template)
if logs:
return case, logs
db.session.commit()
return case, logs

View File

@ -0,0 +1,375 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# Copyright (C) 2021 - Airbus CyberSecurity (SAS)
# ir@cyberactionlab.net
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from pathlib import Path
from datetime import datetime
from sqlalchemy import and_
from sqlalchemy.orm import aliased, contains_eager, subqueryload
from app import db
from app.datamgmt.alerts.alerts_db import search_alert_resolution_by_name
from app.datamgmt.case.case_db import get_case_tags
from app.datamgmt.manage.manage_case_classifications_db import get_case_classification_by_id
from app.datamgmt.manage.manage_case_state_db import get_case_state_by_name
from app.datamgmt.states import delete_case_states
from app.models import CaseAssets, CaseClassification, alert_assets_association, CaseStatus, TaskAssignee, TaskComments
from app.models import CaseEventCategory
from app.models import CaseEventsAssets
from app.models import CaseEventsIoc
from app.models import CaseReceivedFile
from app.models import CaseTasks
from app.models import Cases
from app.models import CasesEvent
from app.models import Client
from app.models import DataStoreFile
from app.models import DataStorePath
from app.models import IocAssetLink
from app.models import IocLink
from app.models import Notes
from app.models import NotesGroup
from app.models import NotesGroupLink
from app.models.alerts import Alert, AlertCaseAssociation
from app.models.authorization import CaseAccessLevel
from app.models.authorization import GroupCaseAccess
from app.models.authorization import OrganisationCaseAccess
from app.models.authorization import User
from app.models import UserActivity
from app.models.authorization import UserCaseAccess
from app.models.authorization import UserCaseEffectiveAccess
from app.models.cases import CaseProtagonist, CaseTags, CaseState
from app.schema.marshables import CaseDetailsSchema
def list_cases_id():
res = Cases.query.with_entities(
Cases.case_id
).all()
return [r.case_id for r in res]
def list_cases_dict_unrestricted():
owner_alias = aliased(User)
user_alias = aliased(User)
res = Cases.query.with_entities(
Cases.name.label('case_name'),
Cases.description.label('case_description'),
Client.name.label('client_name'),
Cases.open_date.label('case_open_date'),
Cases.close_date.label('case_close_date'),
Cases.soc_id.label('case_soc_id'),
Cases.user_id.label('opened_by_user_id'),
user_alias.user.label('opened_by'),
Cases.owner_id,
owner_alias.name.label('owner'),
Cases.case_id
).join(
Cases.client
).join(
user_alias, and_(Cases.user_id == user_alias.id)
).join(
owner_alias, and_(Cases.owner_id == owner_alias.id)
).order_by(
Cases.open_date
).all()
data = []
for row in res:
row = row._asdict()
row['case_open_date'] = row['case_open_date'].strftime("%m/%d/%Y")
row['case_close_date'] = row['case_close_date'].strftime("%m/%d/%Y") if row["case_close_date"] else ""
data.append(row)
return data
def list_cases_dict(user_id):
owner_alias = aliased(User)
user_alias = aliased(User)
res = UserCaseEffectiveAccess.query.with_entities(
Cases.name.label('case_name'),
Cases.description.label('case_description'),
Client.name.label('client_name'),
Cases.open_date.label('case_open_date'),
Cases.close_date.label('case_close_date'),
Cases.soc_id.label('case_soc_id'),
Cases.user_id.label('opened_by_user_id'),
user_alias.user.label('opened_by'),
Cases.owner_id,
owner_alias.name.label('owner'),
Cases.case_id,
Cases.case_uuid,
Cases.classification_id,
CaseClassification.name.label('classification'),
Cases.state_id,
CaseState.state_name,
UserCaseEffectiveAccess.access_level
).join(
UserCaseEffectiveAccess.case,
Cases.client,
Cases.user
).outerjoin(
Cases.classification,
Cases.state
).join(
user_alias, and_(Cases.user_id == user_alias.id)
).join(
owner_alias, and_(Cases.owner_id == owner_alias.id)
).filter(
UserCaseEffectiveAccess.user_id == user_id
).order_by(
Cases.open_date
).all()
data = []
for row in res:
if row.access_level & CaseAccessLevel.deny_all.value == CaseAccessLevel.deny_all.value:
continue
row = row._asdict()
row['case_open_date'] = row['case_open_date'].strftime("%m/%d/%Y")
row['case_close_date'] = row['case_close_date'].strftime("%m/%d/%Y") if row["case_close_date"] else ""
data.append(row)
return data
def user_list_cases_view(user_id):
res = UserCaseEffectiveAccess.query.with_entities(
UserCaseEffectiveAccess.case_id
).filter(and_(
UserCaseEffectiveAccess.user_id == user_id,
UserCaseEffectiveAccess.access_level != CaseAccessLevel.deny_all.value
)).all()
return [r.case_id for r in res]
def close_case(case_id):
res = Cases.query.filter(
Cases.case_id == case_id
).first()
if res:
res.close_date = datetime.utcnow()
res.state_id = get_case_state_by_name('Closed').state_id
db.session.commit()
return res
return None
def map_alert_resolution_to_case_status(case_status_id):
if case_status_id == CaseStatus.false_positive.value:
ares = search_alert_resolution_by_name('False Positive', exact_match=True)
elif case_status_id == CaseStatus.true_positive_with_impact.value:
ares = search_alert_resolution_by_name('True Positive With Impact', exact_match=True)
elif case_status_id == CaseStatus.true_positive_without_impact.value:
ares = search_alert_resolution_by_name('True Positive Without Impact', exact_match=True)
else:
ares = search_alert_resolution_by_name('Not Applicable', exact_match=True)
if ares:
return ares.resolution_status_id
return None
def reopen_case(case_id):
res = Cases.query.filter(
Cases.case_id == case_id
).first()
if res:
res.close_date = None
res.state_id = get_case_state_by_name('Open').state_id
db.session.commit()
return res
return None
def get_case_protagonists(case_id):
protagonists = CaseProtagonist.query.with_entities(
CaseProtagonist.role,
CaseProtagonist.name,
CaseProtagonist.contact,
User.name.label('user_name'),
User.user.label('user_login')
).filter(
CaseProtagonist.case_id == case_id
).outerjoin(
CaseProtagonist.user
).all()
return protagonists
def get_case_details_rt(case_id):
case = Cases.query.filter(Cases.case_id == case_id).first()
if case:
owner_alias = aliased(User)
user_alias = aliased(User)
review_alias = aliased(User)
res = db.session.query(Cases, Client, user_alias, owner_alias).with_entities(
Cases.name.label('case_name'),
Cases.description.label('case_description'),
Cases.open_date, Cases.close_date,
Cases.soc_id.label('case_soc_id'),
Cases.case_id,
Cases.case_uuid,
Client.name.label('customer_name'),
Cases.client_id.label('customer_id'),
Cases.user_id.label('open_by_user_id'),
user_alias.user.label('open_by_user'),
Cases.owner_id,
owner_alias.name.label('owner'),
Cases.status_id,
Cases.state_id,
CaseState.state_name,
Cases.custom_attributes,
Cases.modification_history,
Cases.initial_date,
Cases.classification_id,
CaseClassification.name.label('classification'),
Cases.reviewer_id,
review_alias.name.label('reviewer'),
).filter(and_(
Cases.case_id == case_id
)).join(
user_alias, and_(Cases.user_id == user_alias.id)
).outerjoin(
owner_alias, and_(Cases.owner_id == owner_alias.id)
).outerjoin(
review_alias, and_(Cases.reviewer_id == review_alias.id)
).join(
Cases.client,
).outerjoin(
Cases.classification,
Cases.state
).first()
if res is None:
return None
res = res._asdict()
res['case_tags'] = ",".join(get_case_tags(case_id))
res['status_name'] = CaseStatus(res['status_id']).name.replace("_", " ").title()
res['protagonists'] = [r._asdict() for r in get_case_protagonists(case_id)]
else:
res = None
return res
def delete_case(case_id):
if not Cases.query.filter(Cases.case_id == case_id).first():
return False
delete_case_states(caseid=case_id)
UserActivity.query.filter(UserActivity.case_id == case_id).delete()
CaseReceivedFile.query.filter(CaseReceivedFile.case_id == case_id).delete()
IocLink.query.filter(IocLink.case_id == case_id).delete()
CaseTags.query.filter(CaseTags.case_id == case_id).delete()
CaseProtagonist.query.filter(CaseProtagonist.case_id == case_id).delete()
AlertCaseAssociation.query.filter(AlertCaseAssociation.case_id == case_id).delete()
dsf_list = DataStoreFile.query.filter(DataStoreFile.file_case_id == case_id).all()
for dsf_list_item in dsf_list:
fln = Path(dsf_list_item.file_local_name)
if fln.is_file():
fln.unlink(missing_ok=True)
db.session.delete(dsf_list_item)
db.session.commit()
DataStorePath.query.filter(DataStorePath.path_case_id == case_id).delete()
da = CaseAssets.query.with_entities(CaseAssets.asset_id).filter(CaseAssets.case_id == case_id).all()
for asset in da:
IocAssetLink.query.filter(asset.asset_id == asset.asset_id).delete()
CaseEventsAssets.query.filter(CaseEventsAssets.case_id == case_id).delete()
CaseEventsIoc.query.filter(CaseEventsIoc.case_id == case_id).delete()
CaseAssetsAlias = aliased(CaseAssets)
# Query for CaseAssets that are not referenced in alerts and match the case_id
assets_to_delete = db.session.query(CaseAssets).filter(
and_(
CaseAssets.case_id == case_id,
~db.session.query(alert_assets_association).filter(
alert_assets_association.c.asset_id == CaseAssetsAlias.asset_id
).exists()
)
)
# Delete the assets
assets_to_delete.delete(synchronize_session='fetch')
# Get all alerts associated with assets in the case
alerts_to_update = db.session.query(CaseAssets).filter(CaseAssets.case_id == case_id)
# Update case_id for the alerts
alerts_to_update.update({CaseAssets.case_id: None}, synchronize_session='fetch')
db.session.commit()
NotesGroupLink.query.filter(NotesGroupLink.case_id == case_id).delete()
NotesGroup.query.filter(NotesGroup.group_case_id == case_id).delete()
Notes.query.filter(Notes.note_case_id == case_id).delete()
tasks = CaseTasks.query.filter(CaseTasks.task_case_id == case_id).all()
for task in tasks:
TaskAssignee.query.filter(TaskAssignee.task_id == task.id).delete()
CaseTasks.query.filter(CaseTasks.id == task.id).delete()
da = CasesEvent.query.with_entities(CasesEvent.event_id).filter(CasesEvent.case_id == case_id).all()
for event in da:
CaseEventCategory.query.filter(CaseEventCategory.event_id == event.event_id).delete()
CasesEvent.query.filter(CasesEvent.case_id == case_id).delete()
UserCaseAccess.query.filter(UserCaseAccess.case_id == case_id).delete()
UserCaseEffectiveAccess.query.filter(UserCaseEffectiveAccess.case_id == case_id).delete()
GroupCaseAccess.query.filter(GroupCaseAccess.case_id == case_id).delete()
OrganisationCaseAccess.query.filter(OrganisationCaseAccess.case_id == case_id).delete()
Cases.query.filter(Cases.case_id == case_id).delete()
db.session.commit()
return True

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# Copyright (C) 2023 - DFIR-IRIS
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from sqlalchemy import func
from app import db
from app.models.alerts import Severity
def get_severities_list():
"""
Get a list of severities from the database
returns:
list: A list of severities
"""
return db.session.query(Severity).distinct().all()
def get_severity_by_id(status_id: int) -> Severity:
"""
Get a severity from the database by its ID
args:
status_id (int): The ID of the severity to retrieve
returns:
Severity: The severity object
"""
return db.session.query(Severity).filter(Severity.severity_id == status_id).first()
def search_severity_by_name(name: str, exact_match: bool = True) -> Severity:
"""
Search for a severity by its name
args:
name (str): The name of the severity to search for
exact_match (bool): Whether to search for an exact match or not
returns:
Severity: The severity object
"""
if exact_match:
return db.session.query(Severity).filter(func.lower(Severity.severity_name) == name.lower()).all()
return db.session.query(Severity).filter(Severity.severity_name.ilike(f'%{name}%')).all()

View File

@ -0,0 +1,311 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# contact@dfir-iris.org
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from flask_login import current_user
from sqlalchemy import and_
from app import db
from app.datamgmt.case.case_db import get_case
from app.datamgmt.manage.manage_cases_db import list_cases_id
from app.iris_engine.access_control.utils import ac_access_level_mask_from_val_list, ac_ldp_group_removal
from app.iris_engine.access_control.utils import ac_access_level_to_list
from app.iris_engine.access_control.utils import ac_auto_update_user_effective_access
from app.iris_engine.access_control.utils import ac_permission_to_list
from app.models import Cases
from app.models.authorization import Group
from app.models.authorization import GroupCaseAccess
from app.models.authorization import User
from app.models.authorization import UserGroup
from app.schema.marshables import AuthorizationGroupSchema
def get_groups_list():
groups = Group.query.all()
return groups
def get_groups_list_hr_perms():
groups = get_groups_list()
get_membership_list = UserGroup.query.with_entities(
UserGroup.group_id,
User.user,
User.id,
User.name
).join(UserGroup.user).all()
membership_list = {}
for member in get_membership_list:
if member.group_id not in membership_list:
membership_list[member.group_id] = [{
'user': member.user,
'name': member.name,
'id': member.id
}]
else:
membership_list[member.group_id].append({
'user': member.user,
'name': member.name,
'id': member.id
})
groups = AuthorizationGroupSchema().dump(groups, many=True)
for group in groups:
perms = ac_permission_to_list(group['group_permissions'])
group['group_permissions_list'] = perms
group['group_members'] = membership_list.get(group['group_id'], [])
return groups
def get_group(group_id):
group = Group.query.filter(Group.group_id == group_id).first()
return group
def get_group_by_name(group_name):
groups = Group.query.filter(Group.group_name == group_name)
return groups.first()
def get_group_with_members(group_id):
group = get_group(group_id)
if not group:
return None
get_membership_list = UserGroup.query.with_entities(
UserGroup.group_id,
User.user,
User.id,
User.name
).join(
UserGroup.user
).filter(
UserGroup.group_id == group_id
).all()
membership_list = {}
for member in get_membership_list:
if member.group_id not in membership_list:
membership_list[member.group_id] = [{
'user': member.user,
'name': member.name,
'id': member.id
}]
else:
membership_list[member.group_id].append({
'user': member.user,
'name': member.name,
'id': member.id
})
perms = ac_permission_to_list(group.group_permissions)
setattr(group, 'group_permissions_list', perms)
setattr(group, 'group_members', membership_list.get(group.group_id, []))
return group
def get_group_details(group_id):
group = get_group_with_members(group_id)
if not group:
return group
group_accesses = GroupCaseAccess.query.with_entities(
GroupCaseAccess.access_level,
GroupCaseAccess.case_id,
Cases.name.label('case_name')
).join(
GroupCaseAccess.case
).filter(
GroupCaseAccess.group_id == group_id
).all()
group_cases_access = []
for kgroup in group_accesses:
group_cases_access.append({
"access_level": kgroup.access_level,
"access_level_list": ac_access_level_to_list(kgroup.access_level),
"case_id": kgroup.case_id,
"case_name": kgroup.case_name
})
setattr(group, 'group_cases_access', group_cases_access)
return group
def update_group_members(group, members):
if not group:
return None
cur_groups = UserGroup.query.with_entities(
UserGroup.user_id
).filter(UserGroup.group_id == group.group_id).all()
set_cur_groups = set([grp[0] for grp in cur_groups])
set_members = set(int(mber) for mber in members)
users_to_add = set_members - set_cur_groups
users_to_remove = set_cur_groups - set_members
for uid in users_to_add:
user = User.query.filter(User.id == uid).first()
if user:
ug = UserGroup()
ug.group_id = group.group_id
ug.user_id = user.id
db.session.add(ug)
db.session.commit()
ac_auto_update_user_effective_access(uid)
for uid in users_to_remove:
if current_user.id == uid and ac_ldp_group_removal(uid, group.group_id):
continue
UserGroup.query.filter(
and_(UserGroup.group_id == group.group_id,
UserGroup.user_id == uid)
).delete()
db.session.commit()
ac_auto_update_user_effective_access(uid)
return group
def remove_user_from_group(group, member):
if not group:
return None
UserGroup.query.filter(
and_(UserGroup.group_id == group.group_id,
UserGroup.user_id == member.id)
).delete()
db.session.commit()
ac_auto_update_user_effective_access(member.id)
return group
def delete_group(group):
if not group:
return None
UserGroup.query.filter(UserGroup.group_id == group.group_id).delete()
GroupCaseAccess.query.filter(GroupCaseAccess.group_id == group.group_id).delete()
db.session.delete(group)
db.session.commit()
def add_case_access_to_group(group, cases_list, access_level):
if not group:
return None, "Invalid group"
for case_id in cases_list:
case = get_case(case_id)
if not case:
return None, "Invalid case ID"
access_level_mask = ac_access_level_mask_from_val_list([access_level])
ocas = GroupCaseAccess.query.filter(
and_(
GroupCaseAccess.case_id == case_id,
GroupCaseAccess.group_id == group.group_id
)).all()
if ocas:
for oca in ocas:
db.session.delete(oca)
oca = GroupCaseAccess()
oca.group_id = group.group_id
oca.access_level = access_level_mask
oca.case_id = case_id
db.session.add(oca)
db.session.commit()
return group, "Updated"
def add_all_cases_access_to_group(group, access_level):
if not group:
return None, "Invalid group"
for case_id in list_cases_id():
access_level_mask = ac_access_level_mask_from_val_list([access_level])
ocas = GroupCaseAccess.query.filter(
and_(
GroupCaseAccess.case_id == case_id,
GroupCaseAccess.group_id == group.group_id
)).all()
if ocas:
for oca in ocas:
db.session.delete(oca)
oca = GroupCaseAccess()
oca.group_id = group.group_id
oca.access_level = access_level_mask
oca.case_id = case_id
db.session.add(oca)
db.session.commit()
return group, "Updated"
def remove_case_access_from_group(group_id, case_id):
if not group_id or type(group_id) is not int:
return
if not case_id or type(case_id) is not int:
return
GroupCaseAccess.query.filter(
and_(
GroupCaseAccess.case_id == case_id,
GroupCaseAccess.group_id == group_id
)).delete()
db.session.commit()
return
def remove_cases_access_from_group(group_id, cases_list):
if not group_id or type(group_id) is not int:
return False, "Invalid group"
if not cases_list or type(cases_list[0]) is not int:
return False, "Invalid cases list"
GroupCaseAccess.query.filter(
and_(
GroupCaseAccess.case_id.in_(cases_list),
GroupCaseAccess.group_id == group_id
)).delete()
db.session.commit()
return True, "Updated"

View File

@ -0,0 +1,24 @@
from app import db
from app.models import ServerSettings
from app.schema.marshables import ServerSettingsSchema
def get_srv_settings():
return ServerSettings.query.first()
def get_server_settings_as_dict():
srv_settings = ServerSettings.query.first()
if srv_settings:
sc = ServerSettingsSchema()
return sc.dump(srv_settings)
else:
return {}
def get_alembic_revision():
with db.engine.connect() as con:
version_num = con.execute("SELECT version_num FROM alembic_version").first()[0]
return version_num or None

View File

@ -0,0 +1,640 @@
#!/usr/bin/env python3
#
# IRIS Source Code
# Copyright (C) 2021 - Airbus CyberSecurity (SAS)
# ir@cyberactionlab.net
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from flask_login import current_user
from sqlalchemy import and_
from app import bc
from app import db
from app.datamgmt.case.case_db import get_case
from app.iris_engine.access_control.utils import ac_access_level_mask_from_val_list, ac_ldp_group_removal
from app.iris_engine.access_control.utils import ac_access_level_to_list
from app.iris_engine.access_control.utils import ac_auto_update_user_effective_access
from app.iris_engine.access_control.utils import ac_get_detailed_effective_permissions_from_groups
from app.iris_engine.access_control.utils import ac_remove_case_access_from_user
from app.iris_engine.access_control.utils import ac_set_case_access_for_user
from app.models import Cases
from app.models.authorization import CaseAccessLevel
from app.models.authorization import Group
from app.models.authorization import Organisation
from app.models.authorization import User
from app.models.authorization import UserCaseAccess
from app.models.authorization import UserCaseEffectiveAccess
from app.models.authorization import UserGroup
from app.models.authorization import UserOrganisation
def get_user(user_id, id_key: str = 'id'):
user = User.query.filter(getattr(User, id_key) == user_id).first()
return user
def get_active_user_by_login(username):
user = User.query.filter(
User.user == username,
User.active == True
).first()
return user
def list_users_id():
users = User.query.with_entities(User.user_id).all()
return users
def get_user_effective_permissions(user_id):
groups_perms = UserGroup.query.with_entities(
Group.group_permissions,
Group.group_name
).filter(
UserGroup.user_id == user_id
).join(
UserGroup.group
).all()
effective_permissions = ac_get_detailed_effective_permissions_from_groups(groups_perms)
return effective_permissions
def get_user_groups(user_id):
groups = UserGroup.query.with_entities(
Group.group_name,
Group.group_id,
Group.group_uuid
).filter(
UserGroup.user_id == user_id
).join(
UserGroup.group
).all()
output = []
for group in groups:
output.append(group._asdict())
return output
def update_user_groups(user_id, groups):
cur_groups = UserGroup.query.with_entities(
UserGroup.group_id
).filter(UserGroup.user_id == user_id).all()
set_cur_groups = set([grp[0] for grp in cur_groups])
set_new_groups = set(int(grp) for grp in groups)
groups_to_add = set_new_groups - set_cur_groups
groups_to_remove = set_cur_groups - set_new_groups
for group_id in groups_to_add:
user_group = UserGroup()
user_group.user_id = user_id
user_group.group_id = group_id
db.session.add(user_group)
for group_id in groups_to_remove:
if current_user.id == user_id and ac_ldp_group_removal(user_id=user_id, group_id=group_id):
continue
UserGroup.query.filter(
UserGroup.user_id == user_id,
UserGroup.group_id == group_id
).delete()
db.session.commit()
ac_auto_update_user_effective_access(user_id)
def update_user_orgs(user_id, orgs):
cur_orgs = UserOrganisation.query.with_entities(
UserOrganisation.org_id,
UserOrganisation.is_primary_org
).filter(UserOrganisation.user_id == user_id).all()
updated = False
primary_org = 0
for org in cur_orgs:
if org.is_primary_org:
primary_org = org.org_id
if primary_org == 0:
return False, 'User does not have primary organisation. Set one before managing its organisations'
set_cur_orgs = set([org.org_id for org in cur_orgs])
set_new_orgs = set(int(org) for org in orgs)
orgs_to_add = set_new_orgs - set_cur_orgs
orgs_to_remove = set_cur_orgs - set_new_orgs
for org in orgs_to_add:
user_org = UserOrganisation()
user_org.user_id = user_id
user_org.org_id = org
db.session.add(user_org)
updated = True
for org in orgs_to_remove:
if org != primary_org:
UserOrganisation.query.filter(
UserOrganisation.user_id == user_id,
UserOrganisation.org_id == org
).delete()
else:
db.session.rollback()
return False, f'Cannot delete user from primary organisation {org}. Change it before deleting.'
updated = True
db.session.commit()
ac_auto_update_user_effective_access(user_id)
return True, f'Organisations membership updated' if updated else "Nothing changed"
def change_user_primary_org(user_id, old_org_id, new_org_id):
uo_old = UserOrganisation.query.filter(
UserOrganisation.user_id == user_id,
UserOrganisation.org_id == old_org_id
).first()
uo_new = UserOrganisation.query.filter(
UserOrganisation.user_id == user_id,
UserOrganisation.org_id == new_org_id
).first()
if uo_old:
uo_old.is_primary_org = False
if not uo_new:
uo = UserOrganisation()
uo.user_id = user_id
uo.org_id = new_org_id
uo.is_primary_org = True
db.session.add(uo)
else:
uo_new.is_primary_org = True
db.session.commit()
return
def add_user_to_organisation(user_id, org_id, make_primary=False):
org_id = Organisation.query.first().org_id
uo_exists = UserOrganisation.query.filter(
UserOrganisation.user_id == user_id,
UserOrganisation.org_id == org_id
).first()
if uo_exists:
uo_exists.is_primary_org = make_primary
db.session.commit()
return True
# Check if user has a primary org already
prim_org = get_user_primary_org(user_id=user_id)
if make_primary:
prim_org.is_primary_org = False
db.session.commit()
uo = UserOrganisation()
uo.user_id = user_id
uo.org_id = org_id
uo.is_primary_org = prim_org is None
db.session.add(uo)
db.session.commit()
return True
def get_user_primary_org(user_id):
uo = UserOrganisation.query.filter(
and_(UserOrganisation.user_id == user_id,
UserOrganisation.is_primary_org == True)
).all()
if not uo:
return None
uoe = None
index = 0
if len(uo) > 1:
# Fix potential duplication
for u in uo:
if index == 0:
uoe = u
continue
u.is_primary_org = False
db.session.commit()
else:
uoe = uo[0]
return uoe
def add_user_to_group(user_id, group_id):
exists = UserGroup.query.filter(
UserGroup.user_id == user_id,
UserGroup.group_id == group_id
).scalar()
if exists:
return True
ug = UserGroup()
ug.user_id = user_id
ug.group_id = group_id
db.session.add(ug)
db.session.commit()
return True
def get_user_organisations(user_id):
user_org = UserOrganisation.query.with_entities(
Organisation.org_name,
Organisation.org_id,
Organisation.org_uuid,
UserOrganisation.is_primary_org
).filter(
UserOrganisation.user_id == user_id
).join(
UserOrganisation.org
).all()
output = []
for org in user_org:
output.append(org._asdict())
return output
def get_user_cases_access(user_id):
user_accesses = UserCaseAccess.query.with_entities(
UserCaseAccess.access_level,
UserCaseAccess.case_id,
Cases.name.label('case_name')
).join(
UserCaseAccess.case
).filter(
UserCaseAccess.user_id == user_id
).all()
user_cases_access = []
for kuser in user_accesses:
user_cases_access.append({
"access_level": kuser.access_level,
"access_level_list": ac_access_level_to_list(kuser.access_level),
"case_id": kuser.case_id,
"case_name": kuser.case_name
})
return user_cases_access
def get_user_cases_fast(user_id):
user_cases = UserCaseEffectiveAccess.query.with_entities(
UserCaseEffectiveAccess.case_id
).where(
UserCaseEffectiveAccess.user_id == user_id,
UserCaseEffectiveAccess.access_level != CaseAccessLevel.deny_all.value
).all()
return [c.case_id for c in user_cases]
def remove_cases_access_from_user(user_id, cases_list):
if not user_id or type(user_id) is not int:
return False, 'Invalid user id'
if not cases_list or type(cases_list[0]) is not int:
return False, "Invalid cases list"
UserCaseAccess.query.filter(
and_(
UserCaseAccess.case_id.in_(cases_list),
UserCaseAccess.user_id == user_id
)).delete()
db.session.commit()
ac_auto_update_user_effective_access(user_id)
return True, 'Cases access removed'
def remove_case_access_from_user(user_id, case_id):
if not user_id or type(user_id) is not int:
return False, 'Invalid user id'
if not case_id or type(case_id) is not int:
return False, "Invalid case id"
UserCaseAccess.query.filter(
and_(
UserCaseAccess.case_id == case_id,
UserCaseAccess.user_id == user_id
)).delete()
db.session.commit()
ac_remove_case_access_from_user(user_id, case_id)
return True, 'Case access removed'
def set_user_case_access(user_id, case_id, access_level):
if user_id is None or type(user_id) is not int:
return False, 'Invalid user id'
if case_id is None or type(case_id) is not int:
return False, "Invalid case id"
if access_level is None or type(access_level) is not int:
return False, "Invalid access level"
if CaseAccessLevel.has_value(access_level) is False:
return False, "Invalid access level"
uca = UserCaseAccess.query.filter(
UserCaseAccess.user_id == user_id,
UserCaseAccess.case_id == case_id
).all()
if len(uca) > 1:
for u in uca:
db.session.delete(u)
db.session.commit()
uca = None
if not uca:
uca = UserCaseAccess()
uca.user_id = user_id
uca.case_id = case_id
uca.access_level = access_level
db.session.add(uca)
else:
uca[0].access_level = access_level
db.session.commit()
ac_set_case_access_for_user(user_id, case_id, access_level)
return True, 'Case access set to {} for user {}'.format(access_level, user_id)
def get_user_details(user_id, include_api_key=False):
user = User.query.filter(User.id == user_id).first()
if not user:
return None
row = {}
row['user_id'] = user.id
row['user_uuid'] = user.uuid
row['user_name'] = user.name
row['user_login'] = user.user
row['user_email'] = user.email
row['user_active'] = user.active
row['user_is_service_account'] = user.is_service_account
if include_api_key:
row['user_api_key'] = user.api_key
row['user_groups'] = get_user_groups(user_id)
row['user_organisations'] = get_user_organisations(user_id)
row['user_permissions'] = get_user_effective_permissions(user_id)
row['user_cases_access'] = get_user_cases_access(user_id)
upg = get_user_primary_org(user_id)
row['user_primary_organisation_id'] = upg.org_id if upg else 0
return row
def add_case_access_to_user(user, cases_list, access_level):
if not user:
return None, "Invalid user"
for case_id in cases_list:
case = get_case(case_id)
if not case:
return None, "Invalid case ID"
access_level_mask = ac_access_level_mask_from_val_list([access_level])
ocas = UserCaseAccess.query.filter(
and_(
UserCaseAccess.case_id == case_id,
UserCaseAccess.user_id == user.id
)).all()
if ocas:
for oca in ocas:
db.session.delete(oca)
oca = UserCaseAccess()
oca.user_id = user.id
oca.access_level = access_level_mask
oca.case_id = case_id
db.session.add(oca)
db.session.commit()
ac_auto_update_user_effective_access(user.id)
return user, "Updated"
def get_user_by_username(username):
user = User.query.filter(User.user == username).first()
return user
def get_users_list():
users = User.query.all()
output = []
for user in users:
row = {}
row['user_id'] = user.id
row['user_uuid'] = user.uuid
row['user_name'] = user.name
row['user_login'] = user.user
row['user_email'] = user.email
row['user_active'] = user.active
row['user_is_service_account'] = user.is_service_account
output.append(row)
return output
def get_users_list_restricted():
users = User.query.all()
output = []
for user in users:
row = {}
row['user_id'] = user.id
row['user_uuid'] = user.uuid
row['user_name'] = user.name
row['user_login'] = user.user
row['user_active'] = user.active
output.append(row)
return output
def get_users_view_from_user_id(user_id):
organisations = get_user_organisations(user_id)
orgs_id = [uo.get('org_id') for uo in organisations]
users = UserOrganisation.query.with_entities(
User
).filter(and_(
UserOrganisation.org_id.in_(orgs_id),
UserOrganisation.user_id != user_id
)).join(
UserOrganisation.user
).all()
return users
def get_users_id_view_from_user_id(user_id):
organisations = get_user_organisations(user_id)
orgs_id = [uo.get('org_id') for uo in organisations]
users = UserOrganisation.query.with_entities(
User.id
).filter(and_(
UserOrganisation.org_id.in_(orgs_id),
UserOrganisation.user_id != user_id
)).join(
UserOrganisation.user
).all()
users = [u[0] for u in users]
return users
def get_users_list_user_view(user_id):
users = get_users_view_from_user_id(user_id)
output = []
for user in users:
row = {}
row['user_id'] = user.id
row['user_uuid'] = user.uuid
row['user_name'] = user.name
row['user_login'] = user.user
row['user_email'] = user.email
row['user_active'] = user.active
output.append(row)
return output
def get_users_list_restricted_user_view(user_id):
users = get_users_view_from_user_id(user_id)
output = []
for user in users:
row = {}
row['user_id'] = user.id
row['user_uuid'] = user.uuid
row['user_name'] = user.name
row['user_login'] = user.user
row['user_active'] = user.active
output.append(row)
return output
def get_users_list_restricted_from_case(case_id):
users = UserCaseEffectiveAccess.query.with_entities(
User.id.label('user_id'),
User.uuid.label('user_uuid'),
User.name.label('user_name'),
User.user.label('user_login'),
User.active.label('user_active'),
User.email.label('user_email'),
UserCaseEffectiveAccess.access_level.label('user_access_level')
).filter(
UserCaseEffectiveAccess.case_id == case_id
).join(
UserCaseEffectiveAccess.user
).all()
return [u._asdict() for u in users]
def create_user(user_name: str, user_login: str, user_password: str, user_email: str, user_active: bool,
user_external_id: str = None, user_is_service_account: bool = False):
if user_is_service_account is True and (user_password is None or user_password == ''):
pw_hash = None
else:
pw_hash = bc.generate_password_hash(user_password.encode('utf8')).decode('utf8')
user = User(user=user_login, name=user_name, email=user_email, password=pw_hash, active=user_active,
external_id=user_external_id, is_service_account=user_is_service_account)
user.save()
add_user_to_organisation(user.id, org_id=1)
ac_auto_update_user_effective_access(user_id=user.id)
return user
def update_user(user: User, name: str = None, email: str = None, password: str = None):
if password is not None and password != '':
pw_hash = bc.generate_password_hash(password.encode('utf8')).decode('utf8')
user.password = pw_hash
for key, value in [('name', name,), ('email', email,)]:
if value is not None:
setattr(user, key, value)
db.session.commit()
return user
def delete_user(user_id):
UserCaseAccess.query.filter(UserCaseAccess.user_id == user_id).delete()
UserOrganisation.query.filter(UserOrganisation.user_id == user_id).delete()
UserGroup.query.filter(UserGroup.user_id == user_id).delete()
UserCaseEffectiveAccess.query.filter(UserCaseEffectiveAccess.user_id == user_id).delete()
User.query.filter(User.id == user_id).delete()
db.session.commit()
def user_exists(user_name, user_email):
user = User.query.filter_by(user=user_name).first()
user_by_email = User.query.filter_by(email=user_email).first()
return user or user_by_email