TrainCostModelBlock.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #!/usr/bin/env python
  2. from CBD.src.CBD import *
  3. class CostFunctionBlock(BaseBlock):
  4. def __init__(self, block_name):
  5. BaseBlock.__init__(self, block_name, ["InVi","InVTrain","InDelta","InXPerson"], ["OutCost"])
  6. self.viChanged = False
  7. self.timeInWhichViChanged = 0.0
  8. self.cummulativeCost = 0.0
  9. def compute(self, curIteration):
  10. displacement_person = self.getInputSignal(curIteration, "InXPerson").value
  11. velocity_train = self.getInputSignal(curIteration, "InVTrain").value
  12. # if abs(displacement_person) > 0.4 or velocity_train<0.0:
  13. # raise StopSimulationException()
  14. currentVi = self.getInputSignal(curIteration, "InVi").value
  15. currentTime = self.getClock().getTime()
  16. lastVi = self.getInputSignal(curIteration-1, "InVi").value
  17. if lastVi != currentVi:
  18. self.viChanged = True
  19. self.timeInWhichViChanged = currentTime
  20. self.attainedVelocity = False
  21. else:
  22. self.viChanged = False
  23. lastVTrain = self.getInputSignal(curIteration-1, "InVTrain").value
  24. currentVTrain = self.getInputSignal(curIteration, "InVTrain").value
  25. if ((lastVTrain-currentVi)*(currentVTrain-currentVi) <= 0):
  26. self.attainedVelocity = True
  27. if (not self.attainedVelocity):
  28. instantCostTime = currentTime - self.timeInWhichViChanged
  29. assert instantCostTime >= 0
  30. delta_t = self.getInputSignal(curIteration, "InDelta").value
  31. self.cummulativeCost = self.cummulativeCost + instantCostTime*delta_t
  32. self.appendToSignal(self.cummulativeCost, name_output="OutCost")
  33. class AboveThresholdBlock(BaseBlock):
  34. def __init__(self, block_name, threshold):
  35. BaseBlock.__init__(self, block_name, ["IN1"], ["OUT1"])
  36. self.threshold = threshold
  37. def compute(self, curIteration):
  38. self.appendToSignal(1.0 if self.getInputSignal(curIteration).value > self.threshold else -1.0)
  39. class StopSimulationBlock(BaseBlock):
  40. def __init__(self, block_name):
  41. BaseBlock.__init__(self, block_name, ["IN1"], [])
  42. def compute(self, curIteration):
  43. inSignalValue = self.getInputSignal(curIteration).value
  44. if inSignalValue > 0.0:
  45. raise StopSimulationException()
  46. class StopSimulationException(Exception):
  47. pass