from collections import deque
from typing import Union
from dataclasses import dataclass
@dataclass
class Transaction:
transaction_id: int
transaction_data: dict
class KVStoreWithTransactions:
def __init__(self):
self.kv_store = {}
self.transactions = {}
self.next_transaction_id = 0
def start_transaction(self) -> int:
new_transaction = Transaction(self.next_transaction_id, {})
self.next_transaction_id += 1
self.transactions[new_transaction.transaction_id] = new_transaction
return new_transaction.transaction_id
def commit_transaction(self, transaction_id: int = - 1) -> bool:
if len(self.transactions) < 1:
return False
if transaction_id == -1:
tid_to_commit = self.get_latest_transaction_id()
else:
tid_to_commit = transaction_id
try:
transaction_to_commit = self.transactions[tid_to_commit]
except KeyError:
print(f"Transaction ID {tid_to_commit} not found")
return False
self.kv_store.update(transaction_to_commit.transaction_data)
self.delete(transaction_to_commit.transaction_id)
return True
def rollback_transaction(self, transaction_id: int = -1) -> bool:
if transaction_id == -1:
tid_to_rollback = self.get_latest_transaction_id()
else:
tid_to_rollback = transaction_id
try:
self.transactions.pop(tid_to_rollback)
return(True)
except KeyError:
print(f"Transaction ID {tid_to_rollback} not found")
return False
def get_latest_transaction_id(self) -> Union[None | int]:
if len(self.transactions) < 1:
return None
else:
return self.next_transaction_id - 1
def set(self, key, value) -> bool:
latest_transaction_id = self.get_latest_transaction_id()
if latest_transaction_id is not None:
latest_transaction = self.transactions[latest_transaction_id]
latest_transaction.transaction_data[key] = value
else:
self.kv_store[key] = value
return(True)
def get(self, key) -> Union[None | str]:
if key not in self.kv_store:
return(None)
if self.kv_store[key] == -1:
return None
return(self.kv_store[key])
def delete(self, key):
'''For deletes we're just going to use tombstone -1 in case of rollbacks'''
self.set(key, -1)
if __name__ == "__main__":
test_store = KVStoreWithTransactions()
assert test_store.get("a") == None
test_store.set("a", 1)
assert(test_store.get("a") == 1)
test_store.delete("a")
assert(test_store.get("a") == None)
test_store.set("b", 2)
t_id = test_store.start_transaction()
assert(t_id == 0)
test_store.set("a", 3)
test_store.set("b", 3)
assert(test_store.get("a") == None)
assert(test_store.get("b") == 2)
committed_first = test_store.commit_transaction()
assert(committed_first == True)
assert(test_store.get("a") == 3)
assert(test_store.get("b") == 3)
second_t_id = test_store.start_transaction()
assert(second_t_id == 1)
test_store.set("a", 4)
test_store.set("b", 4)
third_t_id = test_store.start_transaction()
assert(third_t_id == 2)
test_store.set("a", 5)
test_store.set("b", 5)
committed_third = test_store.commit_transaction(third_t_id)
assert(committed_third == True)
assert(test_store.get("a") == 5)
assert(test_store.get("b") == 5)
committed_second = test_store.commit_transaction(second_t_id)
assert(committed_second == True)
assert(test_store.get("a") == 4)
assert(test_store.get("b") == 4)
fourth_t_id = test_store.start_transaction()
assert(fourth_t_id == 3)
test_store.set("a", 6)
test_store.set("b", 6)
rolled_back = test_store.rollback_transaction(fourth_t_id)
assert(rolled_back == True)
assert(test_store.get("a") == 4)
assert(test_store.get("b") == 4)
bad_id = 100
assert(test_store.commit_transaction(bad_id) == False)
assert(test_store.rollback_transaction(bad_id) == False)
assert(test_store.get_latest_transaction_id() == 3)
assert(test_store.get_latest_transaction_id() == 3)
print("All tests pass :-)")